diff --git a/build.gradle b/build.gradle index b9c539f56..ab170a9f4 100644 --- a/build.gradle +++ b/build.gradle @@ -199,9 +199,9 @@ ext { } opensearchplugin { - name 'opensearch-anomaly-detection' - description 'OpenSearch anomaly detector plugin' - classname 'org.opensearch.ad.AnomalyDetectorPlugin' + name 'opensearch-time-series-analytics' + description 'OpenSearch time series analytics plugin' + classname 'org.opensearch.timeseries.TimeSeriesAnalyticsPlugin' extendedPlugins = ['lang-painless', 'opensearch-job-scheduler'] } @@ -655,7 +655,7 @@ task release(type: Copy, group: 'build') { List jacocoExclusions = [ // code for configuration, settings, etc is excluded from coverage - 'org.opensearch.ad.AnomalyDetectorPlugin', + 'org.opensearch.timeseries.TimeSeriesAnalyticsPlugin', // rest layer is tested in integration testing mostly, difficult to mock all of it 'org.opensearch.ad.rest.*', diff --git a/src/main/java/org/opensearch/ad/AbstractProfileRunner.java-e b/src/main/java/org/opensearch/ad/AbstractProfileRunner.java-e new file mode 100644 index 000000000..e402a4da1 --- /dev/null +++ b/src/main/java/org/opensearch/ad/AbstractProfileRunner.java-e @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import java.util.Locale; + +import org.opensearch.ad.model.InitProgressProfile; + +public abstract class AbstractProfileRunner { + protected long requiredSamples; + + public AbstractProfileRunner(long requiredSamples) { + this.requiredSamples = requiredSamples; + } + + protected InitProgressProfile computeInitProgressProfile(long totalUpdates, long intervalMins) { + float percent = Math.min((100.0f * totalUpdates) / requiredSamples, 100.0f); + int neededPoints = (int) (requiredSamples - totalUpdates); + return new InitProgressProfile( + // rounding: 93.456 => 93%, 93.556 => 94% + // Without Locale.ROOT, sometimes conversions use localized decimal digits + // rather than the usual ASCII digits. 
See https://tinyurl.com/y5sdr5tp + String.format(Locale.ROOT, "%.0f%%", percent), + intervalMins * neededPoints, + neededPoints + ); + } +} diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java index fa7682b5e..7c2427847 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java @@ -13,8 +13,8 @@ import static org.opensearch.action.DocWriteResponse.Result.CREATED; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; -import static org.opensearch.ad.AnomalyDetectorPlugin.AD_THREAD_POOL_NAME; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME; import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; import java.io.IOException; diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java-e b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java-e new file mode 100644 index 000000000..7c2427847 --- /dev/null +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java-e @@ -0,0 +1,680 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import static org.opensearch.action.DocWriteResponse.Result.CREATED; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME; +import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; + +import java.io.IOException; +import java.time.Instant; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTaskState; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.AnomalyResultAction; +import org.opensearch.ad.transport.AnomalyResultRequest; +import org.opensearch.ad.transport.AnomalyResultResponse; +import org.opensearch.ad.transport.AnomalyResultTransportAction; +import org.opensearch.ad.util.SecurityUtil; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.InjectSecurity; +import org.opensearch.commons.authuser.User; +import 
org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.jobscheduler.spi.utils.LockService; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ExecutorFunction; + +import com.google.common.base.Throwables; + +/** + * JobScheduler will call AD job runner to get anomaly result periodically + */ +public class AnomalyDetectorJobRunner implements ScheduledJobRunner { + private static final Logger log = LogManager.getLogger(AnomalyDetectorJobRunner.class); + private static AnomalyDetectorJobRunner INSTANCE; + private Settings settings; + private int maxRetryForEndRunException; + private Client client; + private ThreadPool threadPool; + private ConcurrentHashMap detectorEndRunExceptionCount; + private ADIndexManagement anomalyDetectionIndices; + private ADTaskManager adTaskManager; + private NodeStateManager nodeStateManager; + private ExecuteADResultResponseRecorder recorder; + + public static AnomalyDetectorJobRunner getJobRunnerInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (AnomalyDetectorJobRunner.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new AnomalyDetectorJobRunner(); + return INSTANCE; + } + } + + private AnomalyDetectorJobRunner() { + // Singleton class, use getJobRunnerInstance method instead of constructor + this.detectorEndRunExceptionCount = new ConcurrentHashMap<>(); + } + + public void setClient(Client client) { + this.client = client; + } + + public void setThreadPool(ThreadPool threadPool) { + this.threadPool = threadPool; + } + + public void setSettings(Settings settings) { + this.settings = settings; + this.maxRetryForEndRunException = AnomalyDetectorSettings.MAX_RETRY_FOR_END_RUN_EXCEPTION.get(settings); + } + + public void setAdTaskManager(ADTaskManager adTaskManager) { + this.adTaskManager = adTaskManager; + } + + public void setAnomalyDetectionIndices(ADIndexManagement anomalyDetectionIndices) { + this.anomalyDetectionIndices = anomalyDetectionIndices; + } + + public void setNodeStateManager(NodeStateManager nodeStateManager) { + this.nodeStateManager = nodeStateManager; + } + + public void setExecuteADResultResponseRecorder(ExecuteADResultResponseRecorder recorder) { + this.recorder = recorder; + } + + @Override + public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionContext context) { + String detectorId = scheduledJobParameter.getName(); + log.info("Start to run AD job {}", detectorId); + adTaskManager.refreshRealtimeJobRunTime(detectorId); + if (!(scheduledJobParameter instanceof AnomalyDetectorJob)) { + throw new IllegalArgumentException( + "Job parameter is not instance of AnomalyDetectorJob, type: " + scheduledJobParameter.getClass().getCanonicalName() + ); + } + AnomalyDetectorJob jobParameter = (AnomalyDetectorJob) scheduledJobParameter; + Instant executionStartTime = 
Instant.now(); + IntervalSchedule schedule = (IntervalSchedule) jobParameter.getSchedule(); + Instant detectionStartTime = executionStartTime.minus(schedule.getInterval(), schedule.getUnit()); + + final LockService lockService = context.getLockService(); + + Runnable runnable = () -> { + try { + nodeStateManager.getAnomalyDetector(detectorId, ActionListener.wrap(detectorOptional -> { + if (!detectorOptional.isPresent()) { + log.error(new ParameterizedMessage("fail to get detector [{}]", detectorId)); + return; + } + AnomalyDetector detector = detectorOptional.get(); + + if (jobParameter.getLockDurationSeconds() != null) { + lockService + .acquireLock( + jobParameter, + context, + ActionListener + .wrap( + lock -> runAdJob( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + recorder, + detector + ), + exception -> { + indexAnomalyResultException( + jobParameter, + lockService, + null, + detectionStartTime, + executionStartTime, + exception, + false, + recorder, + detector + ); + throw new IllegalStateException("Failed to acquire lock for AD job: " + detectorId); + } + ) + ); + } else { + log.warn("Can't get lock for AD job: " + detectorId); + } + + }, e -> log.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), e))); + } catch (Exception e) { + // The OpenSearch log won't show anything if an exception happens (maybe because this runs on an ExecutorService), + // so we at least log the error. + log.error("Can't start AD job: " + detectorId, e); + throw e; + } + }; + + ExecutorService executor = threadPool.executor(AD_THREAD_POOL_NAME); + executor.submit(runnable); + } + + /** + * Get the anomaly result, index the result, or handle the exception if the call failed. + * + * @param jobParameter scheduled job parameter + * @param lockService lock service + * @param lock lock to run job + * @param detectionStartTime detection start time + * @param executionStartTime detection end time + * @param recorder utility to record job execution result + * @param detector associated detector accessor + */ + protected void runAdJob( + AnomalyDetectorJob jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + ExecuteADResultResponseRecorder recorder, + AnomalyDetector detector + ) { + String detectorId = jobParameter.getName(); + if (lock == null) { + indexAnomalyResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + "Can't run AD job due to null lock", + false, + recorder, + detector + ); + return; + } + anomalyDetectionIndices.update(); + + User userInfo = SecurityUtil.getUserFromJob(jobParameter, settings); + + String user = userInfo.getName(); + List roles = userInfo.getRoles(); + + String resultIndex = jobParameter.getCustomResultIndex(); + if (resultIndex == null) { + runAnomalyDetectionJob( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + detectorId, + user, + roles, + recorder, + detector + ); + return; + } + ActionListener listener = ActionListener.wrap(r -> { log.debug("Custom index is valid"); }, e -> { + Exception exception = new EndRunException(detectorId, e.getMessage(), true); + handleAdException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, exception, recorder, detector); + }); + anomalyDetectionIndices.validateCustomIndexForBackendJob(resultIndex, detectorId, user, roles, () -> { + listener.onResponse(true); + runAnomalyDetectionJob( + jobParameter, + lockService, + lock, + detectionStartTime, +
executionStartTime, + detectorId, + user, + roles, + recorder, + detector + ); + }, listener); + } + + private void runAnomalyDetectionJob( + AnomalyDetectorJob jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + String detectorId, + String user, + List roles, + ExecuteADResultResponseRecorder recorder, + AnomalyDetector detector + ) { + // using one thread in the write threadpool + try (InjectSecurity injectSecurity = new InjectSecurity(detectorId, settings, client.threadPool().getThreadContext())) { + // Injecting user roles to verify that the user has permissions for our API. + injectSecurity.inject(user, roles); + + AnomalyResultRequest request = new AnomalyResultRequest( + detectorId, + detectionStartTime.toEpochMilli(), + executionStartTime.toEpochMilli() + ); + client + .execute( + AnomalyResultAction.INSTANCE, + request, + ActionListener + .wrap( + response -> { + indexAnomalyResult( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + response, + recorder, + detector + ); + }, + exception -> { + handleAdException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + exception, + recorder, + detector + ); + } + ) + ); + } catch (Exception e) { + indexAnomalyResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + e, + true, + recorder, + detector + ); + log.error("Failed to execute AD job " + detectorId, e); + } + } + + /** + * Handle exception from the anomaly result action. + * + * 1. If the exception is an {@link EndRunException}: + * a) if isEndNow == true, stop the AD job and store the exception in the anomaly result; + * b) if isEndNow == false, record the count of {@link EndRunException} for this + * detector. If the count of {@link EndRunException} exceeds the upper limit, + * stop the AD job and store the exception in the anomaly result; otherwise, just + * store the exception in the anomaly result without stopping the AD job for the detector. + * + * 2. If the exception is not an {@link EndRunException}, decrease the count of + * {@link EndRunException} for the detector and index the exception in the anomaly + * result. If the exception is an {@link InternalFailure}, do not log the exception + * stack trace, as it is already logged in {@link AnomalyResultTransportAction}. + * + * TODO: Handle exceptions at a finer granularity, since some exceptions may be + * transient and a retry in the current job may succeed. Currently, we don't + * know which exceptions are transient and retryable in + * {@link AnomalyResultTransportAction}, so we don't add backoff retries + * for now to avoid bringing extra load to the cluster, especially since the cold start + * process is relatively heavy: it sends out 24 queries, initializes + * models, and saves checkpoints. + * Sometimes missing an anomaly and its notification is not acceptable. For example, + * if the detection interval is one hour and there should be an anomaly in the + * current interval, a transient exception may fail the current AD job, + * so no anomaly is found and the user never knows about it. Then we start the next AD job, + * and maybe there is no anomaly in the next hour, so the user never learns that something + * went wrong. In short, this is a tradeoff between protecting + * our performance, the user experience, and what we can do currently.
+ * + * @param jobParameter scheduled job parameter + * @param lockService lock service + * @param lock lock to run job + * @param detectionStartTime detection start time + * @param executionStartTime detection end time + * @param exception exception + * @param recorder utility to record job execution result + * @param detector associated detector accessor + */ + protected void handleAdException( + AnomalyDetectorJob jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + Exception exception, + ExecuteADResultResponseRecorder recorder, + AnomalyDetector detector + ) { + String detectorId = jobParameter.getName(); + if (exception instanceof EndRunException) { + log.error("EndRunException happened when executing anomaly result action for " + detectorId, exception); + + if (((EndRunException) exception).isEndNow()) { + // Stop AD job if EndRunException shows we should end job now. + log.info("JobRunner will stop AD job due to EndRunException for {}", detectorId); + stopAdJobForEndRunException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + (EndRunException) exception, + recorder, + detector + ); + } else { + detectorEndRunExceptionCount.compute(detectorId, (k, v) -> { + if (v == null) { + return 1; + } else { + return v + 1; + } + }); + log.info("EndRunException happened for {}", detectorId); + // if AD job failed consecutively due to EndRunException and failed times exceeds upper limit, will stop AD job + if (detectorEndRunExceptionCount.get(detectorId) > maxRetryForEndRunException) { + log + .info( + "JobRunner will stop AD job due to EndRunException retry exceeds upper limit {} for {}", + maxRetryForEndRunException, + detectorId + ); + stopAdJobForEndRunException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + (EndRunException) exception, + recorder, + detector + ); + return; + } + indexAnomalyResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + exception.getMessage(), + true, + recorder, + detector + ); + } + } else { + detectorEndRunExceptionCount.remove(detectorId); + if (exception instanceof InternalFailure) { + log.error("InternalFailure happened when executing anomaly result action for " + detectorId, exception); + } else { + log.error("Failed to execute anomaly result action for " + detectorId, exception); + } + indexAnomalyResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + exception, + true, + recorder, + detector + ); + } + } + + private void stopAdJobForEndRunException( + AnomalyDetectorJob jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + EndRunException exception, + ExecuteADResultResponseRecorder recorder, + AnomalyDetector detector + ) { + String detectorId = jobParameter.getName(); + detectorEndRunExceptionCount.remove(detectorId); + String errorPrefix = exception.isEndNow() + ? 
"Stopped detector: " + : "Stopped detector as job failed consecutively for more than " + this.maxRetryForEndRunException + " times: "; + String error = errorPrefix + exception.getMessage(); + stopAdJob( + detectorId, + () -> indexAnomalyResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + error, + true, + ADTaskState.STOPPED.name(), + recorder, + detector + ) + ); + } + + private void stopAdJob(String detectorId, ExecutorFunction function) { + GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId); + ActionListener listener = ActionListener.wrap(response -> { + if (response.isExists()) { + try ( + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, response.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser); + if (job.isEnabled()) { + AnomalyDetectorJob newJob = new AnomalyDetectorJob( + job.getName(), + job.getSchedule(), + job.getWindowDelay(), + false, + job.getEnabledTime(), + Instant.now(), + Instant.now(), + job.getLockDurationSeconds(), + job.getUser(), + job.getCustomResultIndex() + ); + IndexRequest indexRequest = new IndexRequest(CommonName.JOB_INDEX) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .source(newJob.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), XCONTENT_WITH_TYPE)) + .id(detectorId); + + client.index(indexRequest, ActionListener.wrap(indexResponse -> { + if (indexResponse != null && (indexResponse.getResult() == CREATED || indexResponse.getResult() == UPDATED)) { + log.info("AD Job was disabled by JobRunner for " + detectorId); + // function.execute(); + } else { + log.warn("Failed to disable AD job for " + detectorId); + } + }, exception -> { log.error("JobRunner failed to update AD job as disabled for " + detectorId, exception); })); + } else { + log.info("AD Job was disabled for " + detectorId); + } + } catch (IOException e) { + log.error("JobRunner failed to stop detector job " + detectorId, e); + } + } else { + log.info("AD Job was not found for " + detectorId); + } + }, exception -> log.error("JobRunner failed to get detector job " + detectorId, exception)); + + client.get(getRequest, ActionListener.runAfter(listener, () -> function.execute())); + } + + private void indexAnomalyResult( + AnomalyDetectorJob jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + AnomalyResultResponse response, + ExecuteADResultResponseRecorder recorder, + AnomalyDetector detector + ) { + String detectorId = jobParameter.getName(); + detectorEndRunExceptionCount.remove(detectorId); + try { + recorder.indexAnomalyResult(detectionStartTime, executionStartTime, response, detector); + } catch (EndRunException e) { + handleAdException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, e, recorder, detector); + } catch (Exception e) { + log.error("Failed to index anomaly result for " + detectorId, e); + } finally { + releaseLock(jobParameter, lockService, lock); + } + + } + + private void indexAnomalyResultException( + AnomalyDetectorJob jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + Exception exception, + boolean releaseLock, + ExecuteADResultResponseRecorder recorder, + AnomalyDetector detector + ) { + try { + String 
errorMessage = exception instanceof TimeSeriesException + ? exception.getMessage() + : Throwables.getStackTraceAsString(exception); + indexAnomalyResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + errorMessage, + releaseLock, + recorder, + detector + ); + } catch (Exception e) { + log.error("Failed to index anomaly result for " + jobParameter.getName(), e); + } + } + + private void indexAnomalyResultException( + AnomalyDetectorJob jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + String errorMessage, + boolean releaseLock, + ExecuteADResultResponseRecorder recorder, + AnomalyDetector detector + ) { + indexAnomalyResultException( + jobParameter, + lockService, + lock, + detectionStartTime, + executionStartTime, + errorMessage, + releaseLock, + null, + recorder, + detector + ); + } + + private void indexAnomalyResultException( + AnomalyDetectorJob jobParameter, + LockService lockService, + LockModel lock, + Instant detectionStartTime, + Instant executionStartTime, + String errorMessage, + boolean releaseLock, + String taskState, + ExecuteADResultResponseRecorder recorder, + AnomalyDetector detector + ) { + try { + recorder.indexAnomalyResultException(detectionStartTime, executionStartTime, errorMessage, taskState, detector); + } finally { + if (releaseLock) { + releaseLock(jobParameter, lockService, lock); + } + } + } + + private void releaseLock(AnomalyDetectorJob jobParameter, LockService lockService, LockModel lock) { + lockService + .release( + lock, + ActionListener + .wrap( + released -> { log.info("Released lock for AD job {}", jobParameter.getName()); }, + exception -> { log.error("Failed to release lock for AD job: " + jobParameter.getName(), exception); } + ) + ); + } +} diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java index 22ad7b369..d9d5e3f7b 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java @@ -12,9 +12,9 @@ package org.opensearch.ad; import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_PARSE_DETECTOR_MSG; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.rest.RestStatus.BAD_REQUEST; -import static org.opensearch.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; import java.util.List; diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java-e b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java-e new file mode 100644 index 000000000..cfec646da --- /dev/null +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java-e @@ -0,0 +1,614 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad; + +import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_PARSE_DETECTOR_MSG; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.core.util.Throwables; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.model.DetectorProfile; +import org.opensearch.ad.model.DetectorProfileName; +import org.opensearch.ad.model.DetectorState; +import org.opensearch.ad.model.InitProgressProfile; +import org.opensearch.ad.settings.ADNumericSetting; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.ProfileAction; +import org.opensearch.ad.transport.ProfileRequest; +import org.opensearch.ad.transport.ProfileResponse; +import org.opensearch.ad.transport.RCFPollingAction; +import org.opensearch.ad.transport.RCFPollingRequest; +import org.opensearch.ad.transport.RCFPollingResponse; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.ad.util.MultiResponsesDelegateActionListener; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; +import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; +import org.opensearch.search.aggregations.metrics.InternalCardinality; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.common.exception.NotSerializedExceptionName; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.transport.TransportService; + +public class AnomalyDetectorProfileRunner extends AbstractProfileRunner { + 
private final Logger logger = LogManager.getLogger(AnomalyDetectorProfileRunner.class); + private Client client; + private SecurityClientUtil clientUtil; + private NamedXContentRegistry xContentRegistry; + private DiscoveryNodeFilterer nodeFilter; + private final TransportService transportService; + private final ADTaskManager adTaskManager; + private final int maxTotalEntitiesToTrack; + + public AnomalyDetectorProfileRunner( + Client client, + SecurityClientUtil clientUtil, + NamedXContentRegistry xContentRegistry, + DiscoveryNodeFilterer nodeFilter, + long requiredSamples, + TransportService transportService, + ADTaskManager adTaskManager + ) { + super(requiredSamples); + this.client = client; + this.clientUtil = clientUtil; + this.xContentRegistry = xContentRegistry; + this.nodeFilter = nodeFilter; + if (requiredSamples <= 0) { + throw new IllegalArgumentException("required samples should be a positive number, but was " + requiredSamples); + } + this.transportService = transportService; + this.adTaskManager = adTaskManager; + this.maxTotalEntitiesToTrack = AnomalyDetectorSettings.MAX_TOTAL_ENTITIES_TO_TRACK; + } + + public void profile(String detectorId, ActionListener listener, Set profilesToCollect) { + if (profilesToCollect.isEmpty()) { + listener.onFailure(new IllegalArgumentException(ADCommonMessages.EMPTY_PROFILES_COLLECT)); + return; + } + calculateTotalResponsesToWait(detectorId, profilesToCollect, listener); + } + + private void calculateTotalResponsesToWait( + String detectorId, + Set profilesToCollect, + ActionListener listener + ) { + GetRequest getDetectorRequest = new GetRequest(CommonName.CONFIG_INDEX, detectorId); + client.get(getDetectorRequest, ActionListener.wrap(getDetectorResponse -> { + if (getDetectorResponse != null && getDetectorResponse.isExists()) { + try ( + XContentParser xContentParser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getDetectorResponse.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, xContentParser.nextToken(), xContentParser); + AnomalyDetector detector = AnomalyDetector.parse(xContentParser, detectorId); + prepareProfile(detector, listener, profilesToCollect); + } catch (Exception e) { + logger.error(FAIL_TO_PARSE_DETECTOR_MSG + detectorId, e); + listener.onFailure(new OpenSearchStatusException(FAIL_TO_PARSE_DETECTOR_MSG + detectorId, BAD_REQUEST)); + } + } else { + listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, BAD_REQUEST)); + } + }, exception -> { + logger.error(FAIL_TO_FIND_CONFIG_MSG + detectorId, exception); + listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, INTERNAL_SERVER_ERROR)); + })); + } + + private void prepareProfile( + AnomalyDetector detector, + ActionListener listener, + Set profilesToCollect + ) { + String detectorId = detector.getId(); + GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX, detectorId); + client.get(getRequest, ActionListener.wrap(getResponse -> { + if (getResponse != null && getResponse.isExists()) { + try ( + XContentParser parser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser); + long enabledTimeMs = job.getEnabledTime().toEpochMilli(); + + boolean isMultiEntityDetector = 
detector.isHighCardinality(); + + int totalResponsesToWait = 0; + if (profilesToCollect.contains(DetectorProfileName.ERROR)) { + totalResponsesToWait++; + } + + // total number of listeners we need to define. Needed by MultiResponsesDelegateActionListener to decide + // when to consolidate results and return to users + if (isMultiEntityDetector) { + if (profilesToCollect.contains(DetectorProfileName.TOTAL_ENTITIES)) { + totalResponsesToWait++; + } + if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE) + || profilesToCollect.contains(DetectorProfileName.SHINGLE_SIZE) + || profilesToCollect.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES) + || profilesToCollect.contains(DetectorProfileName.MODELS) + || profilesToCollect.contains(DetectorProfileName.ACTIVE_ENTITIES) + || profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS) + || profilesToCollect.contains(DetectorProfileName.STATE)) { + totalResponsesToWait++; + } + if (profilesToCollect.contains(DetectorProfileName.AD_TASK)) { + totalResponsesToWait++; + } + } else { + if (profilesToCollect.contains(DetectorProfileName.STATE) + || profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS)) { + totalResponsesToWait++; + } + if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE) + || profilesToCollect.contains(DetectorProfileName.SHINGLE_SIZE) + || profilesToCollect.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES) + || profilesToCollect.contains(DetectorProfileName.MODELS)) { + totalResponsesToWait++; + } + if (profilesToCollect.contains(DetectorProfileName.AD_TASK)) { + totalResponsesToWait++; + } + } + + MultiResponsesDelegateActionListener delegateListener = + new MultiResponsesDelegateActionListener( + listener, + totalResponsesToWait, + ADCommonMessages.FAIL_FETCH_ERR_MSG + detectorId, + false + ); + if (profilesToCollect.contains(DetectorProfileName.ERROR)) { + adTaskManager.getAndExecuteOnLatestDetectorLevelTask(detectorId, ADTaskType.REALTIME_TASK_TYPES, adTask -> { + DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); + if (adTask.isPresent()) { + long lastUpdateTimeMs = adTask.get().getLastUpdateTime().toEpochMilli(); + + // if state index hasn't been updated, we should not use the error field + // For example, before a detector is enabled, if the error message contains + // the phrase "stopped due to blah", we should not show this when the detector + // is enabled. + if (lastUpdateTimeMs > enabledTimeMs && adTask.get().getError() != null) { + profileBuilder.error(adTask.get().getError()); + } + delegateListener.onResponse(profileBuilder.build()); + } else { + // detector state for this detector does not exist + delegateListener.onResponse(profileBuilder.build()); + } + }, transportService, false, delegateListener); + } + + // total number of listeners we need to define. 
Needed by MultiResponsesDelegateActionListener to decide + // when to consolidate results and return to users + if (isMultiEntityDetector) { + if (profilesToCollect.contains(DetectorProfileName.TOTAL_ENTITIES)) { + profileEntityStats(delegateListener, detector); + } + if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE) + || profilesToCollect.contains(DetectorProfileName.SHINGLE_SIZE) + || profilesToCollect.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES) + || profilesToCollect.contains(DetectorProfileName.MODELS) + || profilesToCollect.contains(DetectorProfileName.ACTIVE_ENTITIES) + || profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS) + || profilesToCollect.contains(DetectorProfileName.STATE)) { + profileModels(detector, profilesToCollect, job, true, delegateListener); + } + if (profilesToCollect.contains(DetectorProfileName.AD_TASK)) { + adTaskManager.getLatestHistoricalTaskProfile(detectorId, transportService, null, delegateListener); + } + } else { + if (profilesToCollect.contains(DetectorProfileName.STATE) + || profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS)) { + profileStateRelated(detector, delegateListener, job.isEnabled(), profilesToCollect); + } + if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE) + || profilesToCollect.contains(DetectorProfileName.SHINGLE_SIZE) + || profilesToCollect.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES) + || profilesToCollect.contains(DetectorProfileName.MODELS)) { + profileModels(detector, profilesToCollect, job, false, delegateListener); + } + if (profilesToCollect.contains(DetectorProfileName.AD_TASK)) { + adTaskManager.getLatestHistoricalTaskProfile(detectorId, transportService, null, delegateListener); + } + } + + } catch (Exception e) { + logger.error(ADCommonMessages.FAIL_TO_GET_PROFILE_MSG, e); + listener.onFailure(e); + } + } else { + onGetDetectorForPrepare(detectorId, listener, profilesToCollect); + } + }, exception -> { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + logger.info(exception.getMessage()); + onGetDetectorForPrepare(detectorId, listener, profilesToCollect); + } else { + logger.error(ADCommonMessages.FAIL_TO_GET_PROFILE_MSG + detectorId); + listener.onFailure(exception); + } + })); + } + + private void profileEntityStats(MultiResponsesDelegateActionListener listener, AnomalyDetector detector) { + List categoryField = detector.getCategoryFields(); + if (!detector.isHighCardinality() || categoryField.size() > ADNumericSetting.maxCategoricalFields()) { + listener.onResponse(new DetectorProfile.Builder().build()); + } else { + if (categoryField.size() == 1) { + // Run a cardinality aggregation to count the cardinality of single category fields + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + CardinalityAggregationBuilder aggBuilder = new CardinalityAggregationBuilder(ADCommonName.TOTAL_ENTITIES); + aggBuilder.field(categoryField.get(0)); + searchSourceBuilder.aggregation(aggBuilder); + + SearchRequest request = new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { + Map aggMap = searchResponse.getAggregations().asMap(); + InternalCardinality totalEntities = (InternalCardinality) aggMap.get(ADCommonName.TOTAL_ENTITIES); + long value = totalEntities.getValue(); + DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); + DetectorProfile profile = profileBuilder.totalEntities(value).build(); + 
listener.onResponse(profile); + }, searchException -> { + logger.warn(ADCommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getId()); + listener.onFailure(searchException); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + request, + client::search, + detector.getId(), + client, + searchResponseListener + ); + } else { + // Run a composite query and count the number of buckets to decide cardinality of multiple category fields + AggregationBuilder bucketAggs = AggregationBuilders + .composite( + ADCommonName.TOTAL_ENTITIES, + detector + .getCategoryFields() + .stream() + .map(f -> new TermsValuesSourceBuilder(f).field(f)) + .collect(Collectors.toList()) + ) + .size(maxTotalEntitiesToTrack); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().aggregation(bucketAggs).trackTotalHits(false).size(0); + SearchRequest searchRequest = new SearchRequest() + .indices(detector.getIndices().toArray(new String[0])) + .source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { + DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); + Aggregations aggs = searchResponse.getAggregations(); + if (aggs == null) { + // This would indicate some bug or some opensearch core changes that we are not aware of (we don't keep up-to-date + // with + // the large amounts of changes there). For example, they may change to if there are results return it; otherwise + // return + // null instead of an empty Aggregations as they currently do. + logger.warn("Unexpected null aggregation."); + listener.onResponse(profileBuilder.totalEntities(0L).build()); + return; + } + + Aggregation aggrResult = aggs.get(ADCommonName.TOTAL_ENTITIES); + if (aggrResult == null) { + listener.onFailure(new IllegalArgumentException("Fail to find valid aggregation result")); + return; + } + + CompositeAggregation compositeAgg = (CompositeAggregation) aggrResult; + DetectorProfile profile = profileBuilder.totalEntities(Long.valueOf(compositeAgg.getBuckets().size())).build(); + listener.onResponse(profile); + }, searchException -> { + logger.warn(ADCommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getId()); + listener.onFailure(searchException); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + detector.getId(), + client, + searchResponseListener + ); + } + + } + } + + private void onGetDetectorForPrepare(String detectorId, ActionListener listener, Set profiles) { + DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); + if (profiles.contains(DetectorProfileName.STATE)) { + profileBuilder.state(DetectorState.DISABLED); + } + if (profiles.contains(DetectorProfileName.AD_TASK)) { + adTaskManager.getLatestHistoricalTaskProfile(detectorId, transportService, profileBuilder.build(), listener); + } else { + listener.onResponse(profileBuilder.build()); + } + } + + /** + * We expect three kinds of states: + * -Disabled: if get ad job api says the job is disabled; + * -Init: if rcf model's total updates is less than required + * -Running: if neither of the above applies and no exceptions. 
+ * @param detector anomaly detector + * @param listener listener to process the returned state or exception + * @param enabled whether the detector job is enabled or not + * @param profilesToCollect target profiles to fetch + */ + private void profileStateRelated( + AnomalyDetector detector, + MultiResponsesDelegateActionListener listener, + boolean enabled, + Set profilesToCollect + ) { + if (enabled) { + RCFPollingRequest request = new RCFPollingRequest(detector.getId()); + client.execute(RCFPollingAction.INSTANCE, request, onPollRCFUpdates(detector, profilesToCollect, listener)); + } else { + DetectorProfile.Builder builder = new DetectorProfile.Builder(); + if (profilesToCollect.contains(DetectorProfileName.STATE)) { + builder.state(DetectorState.DISABLED); + } + listener.onResponse(builder.build()); + } + } + + private void profileModels( + AnomalyDetector detector, + Set profiles, + AnomalyDetectorJob job, + boolean forMultiEntityDetector, + MultiResponsesDelegateActionListener listener + ) { + DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); + ProfileRequest profileRequest = new ProfileRequest(detector.getId(), profiles, forMultiEntityDetector, dataNodes); + client.execute(ProfileAction.INSTANCE, profileRequest, onModelResponse(detector, profiles, job, listener));// get init progress + } + + private ActionListener onModelResponse( + AnomalyDetector detector, + Set profilesToCollect, + AnomalyDetectorJob job, + MultiResponsesDelegateActionListener listener + ) { + boolean isMultientityDetector = detector.isHighCardinality(); + return ActionListener.wrap(profileResponse -> { + DetectorProfile.Builder profile = new DetectorProfile.Builder(); + if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE)) { + profile.coordinatingNode(profileResponse.getCoordinatingNode()); + } + if (profilesToCollect.contains(DetectorProfileName.SHINGLE_SIZE)) { + profile.shingleSize(profileResponse.getShingleSize()); + } + if (profilesToCollect.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES)) { + profile.totalSizeInBytes(profileResponse.getTotalSizeInBytes()); + } + if (profilesToCollect.contains(DetectorProfileName.MODELS)) { + profile.modelProfile(profileResponse.getModelProfile()); + profile.modelCount(profileResponse.getModelCount()); + } + if (isMultientityDetector && profilesToCollect.contains(DetectorProfileName.ACTIVE_ENTITIES)) { + profile.activeEntities(profileResponse.getActiveEntities()); + } + + if (isMultientityDetector + && (profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS) + || profilesToCollect.contains(DetectorProfileName.STATE))) { + profileMultiEntityDetectorStateRelated(job, profilesToCollect, profileResponse, profile, detector, listener); + } else { + listener.onResponse(profile.build()); + } + }, listener::onFailure); + } + + private void profileMultiEntityDetectorStateRelated( + AnomalyDetectorJob job, + Set profilesToCollect, + ProfileResponse profileResponse, + DetectorProfile.Builder profileBuilder, + AnomalyDetector detector, + MultiResponsesDelegateActionListener listener + ) { + if (job.isEnabled()) { + if (profileResponse.getTotalUpdates() < requiredSamples) { + // need to double check since what ProfileResponse returns is the highest priority entity currently in memory, but + // another entity might have already been initialized and sit somewhere else (in memory or on disk). 
+ long enabledTime = job.getEnabledTime().toEpochMilli(); + long totalUpdates = profileResponse.getTotalUpdates(); + ProfileUtil + .confirmDetectorRealtimeInitStatus( + detector, + enabledTime, + client, + onInittedEver(enabledTime, profileBuilder, profilesToCollect, detector, totalUpdates, listener) + ); + } else { + createRunningStateAndInitProgress(profilesToCollect, profileBuilder); + listener.onResponse(profileBuilder.build()); + } + } else { + if (profilesToCollect.contains(DetectorProfileName.STATE)) { + profileBuilder.state(DetectorState.DISABLED); + } + listener.onResponse(profileBuilder.build()); + } + } + + private ActionListener onInittedEver( + long lastUpdateTimeMs, + DetectorProfile.Builder profileBuilder, + Set profilesToCollect, + AnomalyDetector detector, + long totalUpdates, + MultiResponsesDelegateActionListener listener + ) { + return ActionListener.wrap(searchResponse -> { + SearchHits hits = searchResponse.getHits(); + if (hits.getTotalHits().value == 0L) { + processInitResponse(detector, profilesToCollect, totalUpdates, false, profileBuilder, listener); + } else { + createRunningStateAndInitProgress(profilesToCollect, profileBuilder); + listener.onResponse(profileBuilder.build()); + } + }, exception -> { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + // anomaly result index is not created yet + processInitResponse(detector, profilesToCollect, totalUpdates, false, profileBuilder, listener); + } else { + logger + .error( + "Fail to find any anomaly result with anomaly score larger than 0 after AD job enabled time for detector {}", + detector.getId() + ); + listener.onFailure(exception); + } + }); + } + + /** + * Listener for polling rcf updates through transport messaging + * @param detector anomaly detector + * @param profilesToCollect profiles to collect like state + * @param listener delegate listener + * @return Listener for polling rcf updates through transport messaging + */ + private ActionListener onPollRCFUpdates( + AnomalyDetector detector, + Set profilesToCollect, + MultiResponsesDelegateActionListener listener + ) { + return ActionListener.wrap(rcfPollResponse -> { + long totalUpdates = rcfPollResponse.getTotalUpdates(); + if (totalUpdates < requiredSamples) { + processInitResponse(detector, profilesToCollect, totalUpdates, false, new DetectorProfile.Builder(), listener); + } else { + DetectorProfile.Builder builder = new DetectorProfile.Builder(); + createRunningStateAndInitProgress(profilesToCollect, builder); + listener.onResponse(builder.build()); + } + }, exception -> { + // we will get an AnomalyDetectionException wrapping the real exception inside + Throwable cause = Throwables.getRootCause(exception); + + // exception can be a RemoteTransportException + Exception causeException = (Exception) cause; + if (ExceptionUtil + .isException( + causeException, + ResourceNotFoundException.class, + NotSerializedExceptionName.RESOURCE_NOT_FOUND_EXCEPTION_NAME_UNDERSCORE.getName() + ) + || (ExceptionUtil.isIndexNotAvailable(causeException) + && causeException.getMessage().contains(ADCommonName.CHECKPOINT_INDEX_NAME))) { + // cannot find checkpoint + // We don't want to show the estimated time remaining to initialize + // a detector before cold start finishes, where the actual + // initialization time may be much shorter if sufficient historical + // data exists. 
+ processInitResponse(detector, profilesToCollect, 0L, true, new DetectorProfile.Builder(), listener); + } else { + logger.error(new ParameterizedMessage("Fail to get init progress through messaging for {}", detector.getId()), exception); + listener.onFailure(exception); + } + }); + } + + private void createRunningStateAndInitProgress(Set profilesToCollect, DetectorProfile.Builder builder) { + if (profilesToCollect.contains(DetectorProfileName.STATE)) { + builder.state(DetectorState.RUNNING).build(); + } + + if (profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS)) { + InitProgressProfile initProgress = new InitProgressProfile("100%", 0, 0); + builder.initProgress(initProgress); + } + } + + private void processInitResponse( + AnomalyDetector detector, + Set profilesToCollect, + long totalUpdates, + boolean hideMinutesLeft, + DetectorProfile.Builder builder, + MultiResponsesDelegateActionListener listener + ) { + if (profilesToCollect.contains(DetectorProfileName.STATE)) { + builder.state(DetectorState.INIT); + } + + if (profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS)) { + if (hideMinutesLeft) { + InitProgressProfile initProgress = computeInitProgressProfile(totalUpdates, 0); + builder.initProgress(initProgress); + } else { + long intervalMins = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().toMinutes(); + InitProgressProfile initProgress = computeInitProgressProfile(totalUpdates, intervalMins); + builder.initProgress(initProgress); + } + } + + listener.onResponse(builder.build()); + } +} diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java-e b/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java-e new file mode 100644 index 000000000..90b7d350f --- /dev/null +++ b/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java-e @@ -0,0 +1,225 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchSecurityException; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.Features; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.model.EntityAnomalyResult; +import org.opensearch.ad.util.MultiResponsesDelegateActionListener; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.FeatureData; + +/** + * Runner to trigger an anomaly detector. 
+ */ +public final class AnomalyDetectorRunner { + + private final Logger logger = LogManager.getLogger(AnomalyDetectorRunner.class); + private final ModelManager modelManager; + private final FeatureManager featureManager; + private final int maxPreviewResults; + + public AnomalyDetectorRunner(ModelManager modelManager, FeatureManager featureManager, int maxPreviewResults) { + this.modelManager = modelManager; + this.featureManager = featureManager; + this.maxPreviewResults = maxPreviewResults; + } + + /** + * run anomaly detector and return anomaly result. + * + * @param detector anomaly detector instance + * @param startTime detection period start time + * @param endTime detection period end time + * @param context stored thread context + * @param listener handle anomaly result + * @throws IOException - if a user gives wrong query input when defining a detector + */ + public void executeDetector( + AnomalyDetector detector, + Instant startTime, + Instant endTime, + ThreadContext.StoredContext context, + ActionListener> listener + ) throws IOException { + context.restore(); + List categoryField = detector.getCategoryFields(); + if (categoryField != null && !categoryField.isEmpty()) { + featureManager.getPreviewEntities(detector, startTime.toEpochMilli(), endTime.toEpochMilli(), ActionListener.wrap(entities -> { + + if (entities == null || entities.isEmpty()) { + // TODO return exception like IllegalArgumentException to explain data is not enough for preview + // This also requires front-end change to handle error message correspondingly + // We return empty list for now to avoid breaking front-end + listener.onResponse(Collections.emptyList()); + return; + } + ActionListener entityAnomalyResultListener = ActionListener + .wrap( + entityAnomalyResult -> { listener.onResponse(entityAnomalyResult.getAnomalyResults()); }, + e -> onFailure(e, listener, detector.getId()) + ); + MultiResponsesDelegateActionListener multiEntitiesResponseListener = + new MultiResponsesDelegateActionListener( + entityAnomalyResultListener, + entities.size(), + String.format(Locale.ROOT, "Fail to get preview result for multi entity detector %s", detector.getId()), + true + ); + for (Entity entity : entities) { + featureManager + .getPreviewFeaturesForEntity( + detector, + entity, + startTime.toEpochMilli(), + endTime.toEpochMilli(), + ActionListener.wrap(features -> { + List entityResults = modelManager + .getPreviewResults(features.getProcessedFeatures(), detector.getShingleSize()); + List sampledEntityResults = sample( + parsePreviewResult(detector, features, entityResults, entity), + maxPreviewResults + ); + multiEntitiesResponseListener.onResponse(new EntityAnomalyResult(sampledEntityResults)); + }, e -> multiEntitiesResponseListener.onFailure(e)) + ); + } + }, e -> onFailure(e, listener, detector.getId()))); + } else { + featureManager.getPreviewFeatures(detector, startTime.toEpochMilli(), endTime.toEpochMilli(), ActionListener.wrap(features -> { + try { + List results = modelManager + .getPreviewResults(features.getProcessedFeatures(), detector.getShingleSize()); + listener.onResponse(sample(parsePreviewResult(detector, features, results, null), maxPreviewResults)); + } catch (Exception e) { + onFailure(e, listener, detector.getId()); + } + }, e -> onFailure(e, listener, detector.getId()))); + } + } + + private void onFailure(Exception e, ActionListener> listener, String detectorId) { + logger.info("Fail to preview anomaly detector " + detectorId, e); + // TODO return exception like IllegalArgumentException 
to explain data is not enough for preview + // This also requires front-end change to handle error message correspondingly + // We return empty list for now to avoid breaking front-end + if (e instanceof OpenSearchSecurityException) { + listener.onFailure(e); + return; + } + listener.onResponse(Collections.emptyList()); + } + + private List parsePreviewResult( + AnomalyDetector detector, + Features features, + List results, + Entity entity + ) { + // unprocessedFeatures[][], each row is for one date range. + // For example, unprocessedFeatures[0][2] is for the first time range, the third feature + double[][] unprocessedFeatures = features.getUnprocessedFeatures(); + List> timeRanges = features.getTimeRanges(); + List featureAttributes = detector.getFeatureAttributes().stream().filter(Feature::getEnabled).collect(Collectors.toList()); + + List anomalyResults = new ArrayList<>(); + if (timeRanges != null && timeRanges.size() > 0) { + for (int i = 0; i < timeRanges.size(); i++) { + Map.Entry timeRange = timeRanges.get(i); + + List featureDatas = new ArrayList<>(); + int featureSize = featureAttributes.size(); + for (int j = 0; j < featureSize; j++) { + double value = unprocessedFeatures[i][j]; + Feature feature = featureAttributes.get(j); + FeatureData data = new FeatureData(feature.getId(), feature.getName(), value); + featureDatas.add(data); + } + + AnomalyResult result; + if (results != null && results.size() > i) { + ThresholdingResult thresholdingResult = results.get(i); + List resultsToSave = thresholdingResult + .toIndexableResults( + detector, + Instant.ofEpochMilli(timeRange.getKey()), + Instant.ofEpochMilli(timeRange.getValue()), + null, + null, + featureDatas, + Optional.ofNullable(entity), + CommonValue.NO_SCHEMA_VERSION, + null, + null, + null + ); + for (AnomalyResult r : resultsToSave) { + anomalyResults.add(r); + } + } else { + result = new AnomalyResult( + detector.getId(), + null, + featureDatas, + Instant.ofEpochMilli(timeRange.getKey()), + Instant.ofEpochMilli(timeRange.getValue()), + null, + null, + null, + Optional.ofNullable(entity), + detector.getUser(), + CommonValue.NO_SCHEMA_VERSION, + null + ); + anomalyResults.add(result); + } + } + } + return anomalyResults; + } + + private List sample(List results, int sampleSize) { + if (results.size() <= sampleSize) { + return results; + } else { + double stepSize = (results.size() - 1.0) / (sampleSize - 1.0); + List samples = new ArrayList<>(sampleSize); + for (int i = 0; i < sampleSize; i++) { + int index = Math.min((int) (stepSize * i), results.size() - 1); + samples.add(results.get(index)); + } + return samples; + } + } + +} diff --git a/src/main/java/org/opensearch/ad/CleanState.java-e b/src/main/java/org/opensearch/ad/CleanState.java-e new file mode 100644 index 000000000..ae8085e88 --- /dev/null +++ b/src/main/java/org/opensearch/ad/CleanState.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +/** + * Represent a state organized via detectorId. When deleting a detector's state, + * we can remove it from the state. 
+ * + * + */ +public interface CleanState { + /** + * Remove state associated with a detector Id + * @param detectorId Detector Id + */ + void clear(String detectorId); +} diff --git a/src/main/java/org/opensearch/ad/DetectorModelSize.java-e b/src/main/java/org/opensearch/ad/DetectorModelSize.java-e new file mode 100644 index 000000000..52e4660e6 --- /dev/null +++ b/src/main/java/org/opensearch/ad/DetectorModelSize.java-e @@ -0,0 +1,24 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import java.util.Map; + +public interface DetectorModelSize { + /** + * Gets all of a detector's model sizes hosted on a node + * + * @param detectorId Detector Id + * @return a map of model id to its memory size + */ + Map getModelSize(String detectorId); +} diff --git a/src/main/java/org/opensearch/ad/EntityProfileRunner.java b/src/main/java/org/opensearch/ad/EntityProfileRunner.java index 1078f8a59..491e8088f 100644 --- a/src/main/java/org/opensearch/ad/EntityProfileRunner.java +++ b/src/main/java/org/opensearch/ad/EntityProfileRunner.java @@ -11,7 +11,7 @@ package org.opensearch.ad; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.util.List; import java.util.Map; diff --git a/src/main/java/org/opensearch/ad/EntityProfileRunner.java-e b/src/main/java/org/opensearch/ad/EntityProfileRunner.java-e new file mode 100644 index 000000000..491e8088f --- /dev/null +++ b/src/main/java/org/opensearch/ad/EntityProfileRunner.java-e @@ -0,0 +1,474 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.model.EntityProfile; +import org.opensearch.ad.model.EntityProfileName; +import org.opensearch.ad.model.EntityState; +import org.opensearch.ad.model.InitProgressProfile; +import org.opensearch.ad.settings.ADNumericSetting; +import org.opensearch.ad.transport.EntityProfileAction; +import org.opensearch.ad.transport.EntityProfileRequest; +import org.opensearch.ad.transport.EntityProfileResponse; +import org.opensearch.ad.util.MultiResponsesDelegateActionListener; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.routing.Preference; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.util.ParseUtils; + +public class EntityProfileRunner extends AbstractProfileRunner { + private final Logger logger = LogManager.getLogger(EntityProfileRunner.class); + + static final String NOT_HC_DETECTOR_ERR_MSG = "This is not a high cardinality detector"; + static final String EMPTY_ENTITY_ATTRIBUTES = "Empty entity attributes"; + static final String NO_ENTITY = "Cannot find entity"; + private Client client; + private SecurityClientUtil clientUtil; + private NamedXContentRegistry xContentRegistry; + + public EntityProfileRunner(Client client, SecurityClientUtil clientUtil, NamedXContentRegistry xContentRegistry, long requiredSamples) { + super(requiredSamples); + this.client = client; + this.clientUtil = clientUtil; + this.xContentRegistry = xContentRegistry; + } + + /** + * Get profile info of specific entity. 
+ * + * @param detectorId detector identifier + * @param entityValue entity value + * @param profilesToCollect profiles to collect + * @param listener action listener to handle exception and process entity profile response + */ + public void profile( + String detectorId, + Entity entityValue, + Set profilesToCollect, + ActionListener listener + ) { + if (profilesToCollect == null || profilesToCollect.size() == 0) { + listener.onFailure(new IllegalArgumentException(ADCommonMessages.EMPTY_PROFILES_COLLECT)); + return; + } + GetRequest getDetectorRequest = new GetRequest(CommonName.CONFIG_INDEX, detectorId); + + client.get(getDetectorRequest, ActionListener.wrap(getResponse -> { + if (getResponse != null && getResponse.isExists()) { + try ( + XContentParser parser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetector detector = AnomalyDetector.parse(parser, detectorId); + List categoryFields = detector.getCategoryFields(); + int maxCategoryFields = ADNumericSetting.maxCategoricalFields(); + if (categoryFields == null || categoryFields.size() == 0) { + listener.onFailure(new IllegalArgumentException(NOT_HC_DETECTOR_ERR_MSG)); + } else if (categoryFields.size() > maxCategoryFields) { + listener.onFailure(new IllegalArgumentException(CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields))); + } else { + validateEntity(entityValue, categoryFields, detectorId, profilesToCollect, detector, listener); + } + } catch (Exception t) { + listener.onFailure(t); + } + } else { + listener.onFailure(new IllegalArgumentException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + detectorId)); + } + }, listener::onFailure)); + } + + /** + * Verify if the input entity exists or not in case of typos. + * + * If a user deletes the entity after job start, then we will not be able to + * get this entity in the index. For this case, we will not return a profile + * for this entity even if it's running on some data node. the entity's model + * will be deleted by another entity or by maintenance due to long inactivity. + * + * @param entity Entity accessor + * @param categoryFields category fields defined for a detector + * @param detectorId Detector Id + * @param profilesToCollect Profile to collect from the input + * @param detector Detector config accessor + * @param listener Callback to send responses. 
+ */ + private void validateEntity( + Entity entity, + List categoryFields, + String detectorId, + Set profilesToCollect, + AnomalyDetector detector, + ActionListener listener + ) { + Map attributes = entity.getAttributes(); + if (attributes == null || attributes.size() != categoryFields.size()) { + listener.onFailure(new IllegalArgumentException(EMPTY_ENTITY_ATTRIBUTES)); + return; + } + for (String field : categoryFields) { + if (false == attributes.containsKey(field)) { + listener.onFailure(new IllegalArgumentException("Cannot find " + field)); + return; + } + } + + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(detector.getFilterQuery()); + + for (TermQueryBuilder term : entity.getTermQueryBuilders()) { + internalFilterQuery.filter(term); + } + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(internalFilterQuery).size(1); + + SearchRequest searchRequest = new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder) + .preference(Preference.LOCAL.toString()); + final ActionListener searchResponseListener = ActionListener.wrap(searchResponse -> { + try { + if (searchResponse.getHits().getHits().length == 0) { + listener.onFailure(new IllegalArgumentException(NO_ENTITY)); + return; + } + prepareEntityProfile(listener, detectorId, entity, profilesToCollect, detector, categoryFields.get(0)); + } catch (Exception e) { + listener.onFailure(new IllegalArgumentException(NO_ENTITY)); + return; + } + }, e -> listener.onFailure(new IllegalArgumentException(NO_ENTITY))); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + detector.getId(), + client, + searchResponseListener + ); + + } + + private void prepareEntityProfile( + ActionListener listener, + String detectorId, + Entity entityValue, + Set profilesToCollect, + AnomalyDetector detector, + String categoryField + ) { + EntityProfileRequest request = new EntityProfileRequest(detectorId, entityValue, profilesToCollect); + + client + .execute( + EntityProfileAction.INSTANCE, + request, + ActionListener.wrap(r -> getJob(detectorId, entityValue, profilesToCollect, detector, r, listener), listener::onFailure) + ); + } + + private void getJob( + String detectorId, + Entity entityValue, + Set profilesToCollect, + AnomalyDetector detector, + EntityProfileResponse entityProfileResponse, + ActionListener listener + ) { + GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX, detectorId); + client.get(getRequest, ActionListener.wrap(getResponse -> { + if (getResponse != null && getResponse.isExists()) { + try ( + XContentParser parser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser); + + int totalResponsesToWait = 0; + if (profilesToCollect.contains(EntityProfileName.INIT_PROGRESS) + || profilesToCollect.contains(EntityProfileName.STATE)) { + totalResponsesToWait++; + } + if (profilesToCollect.contains(EntityProfileName.ENTITY_INFO)) { + totalResponsesToWait++; + } + if (profilesToCollect.contains(EntityProfileName.MODELS)) { + totalResponsesToWait++; + } + MultiResponsesDelegateActionListener delegateListener = + new MultiResponsesDelegateActionListener( + listener, + 
totalResponsesToWait, + ADCommonMessages.FAIL_FETCH_ERR_MSG + entityValue + " of detector " + detectorId, + false + ); + + if (profilesToCollect.contains(EntityProfileName.MODELS)) { + EntityProfile.Builder builder = new EntityProfile.Builder(); + if (false == job.isEnabled()) { + delegateListener.onResponse(builder.build()); + } else { + delegateListener.onResponse(builder.modelProfile(entityProfileResponse.getModelProfile()).build()); + } + } + + if (profilesToCollect.contains(EntityProfileName.INIT_PROGRESS) + || profilesToCollect.contains(EntityProfileName.STATE)) { + profileStateRelated( + entityProfileResponse.getTotalUpdates(), + detectorId, + entityValue, + profilesToCollect, + detector, + job, + delegateListener + ); + } + + if (profilesToCollect.contains(EntityProfileName.ENTITY_INFO)) { + long enabledTimeMs = job.getEnabledTime().toEpochMilli(); + SearchRequest lastSampleTimeRequest = createLastSampleTimeRequest( + detectorId, + enabledTimeMs, + entityValue, + detector.getCustomResultIndex() + ); + + EntityProfile.Builder builder = new EntityProfile.Builder(); + + Optional isActiveOp = entityProfileResponse.isActive(); + if (isActiveOp.isPresent()) { + builder.isActive(isActiveOp.get()); + } + builder.lastActiveTimestampMs(entityProfileResponse.getLastActiveMs()); + + client.search(lastSampleTimeRequest, ActionListener.wrap(searchResponse -> { + Optional latestSampleTimeMs = ParseUtils.getLatestDataTime(searchResponse); + + if (latestSampleTimeMs.isPresent()) { + builder.lastSampleTimestampMs(latestSampleTimeMs.get()); + } + + delegateListener.onResponse(builder.build()); + }, exception -> { + // sth wrong like result index not created. Return what we have + if (exception instanceof IndexNotFoundException) { + // don't print out stack trace since it is not helpful + logger.info("Result index hasn't been created", exception.getMessage()); + } else { + logger.warn("fail to get last sample time", exception); + } + delegateListener.onResponse(builder.build()); + })); + } + } catch (Exception e) { + logger.error(ADCommonMessages.FAIL_TO_GET_PROFILE_MSG, e); + listener.onFailure(e); + } + } else { + sendUnknownState(profilesToCollect, entityValue, true, listener); + } + }, exception -> { + if (exception instanceof IndexNotFoundException) { + logger.info(exception.getMessage()); + sendUnknownState(profilesToCollect, entityValue, true, listener); + } else { + logger.error(ADCommonMessages.FAIL_TO_GET_PROFILE_MSG + detectorId, exception); + listener.onFailure(exception); + } + })); + } + + private void profileStateRelated( + long totalUpdates, + String detectorId, + Entity entityValue, + Set profilesToCollect, + AnomalyDetector detector, + AnomalyDetectorJob job, + MultiResponsesDelegateActionListener delegateListener + ) { + if (totalUpdates == 0) { + sendUnknownState(profilesToCollect, entityValue, false, delegateListener); + } else if (false == job.isEnabled()) { + sendUnknownState(profilesToCollect, entityValue, false, delegateListener); + } else if (totalUpdates >= requiredSamples) { + sendRunningState(profilesToCollect, entityValue, delegateListener); + } else { + sendInitState(profilesToCollect, entityValue, detector, totalUpdates, delegateListener); + } + } + + /** + * Send unknown state back + * @param profilesToCollect Profiles to Collect + * @param entityValue Entity value + * @param immediate whether we should terminate workflow and respond immediately + * @param delegateListener Delegate listener + */ + private void sendUnknownState( + Set profilesToCollect, + Entity 
entityValue, + boolean immediate, + ActionListener delegateListener + ) { + EntityProfile.Builder builder = new EntityProfile.Builder(); + if (profilesToCollect.contains(EntityProfileName.STATE)) { + builder.state(EntityState.UNKNOWN); + } + if (immediate) { + delegateListener.onResponse(builder.build()); + } else { + delegateListener.onResponse(builder.build()); + } + } + + private void sendRunningState( + Set profilesToCollect, + Entity entityValue, + MultiResponsesDelegateActionListener delegateListener + ) { + EntityProfile.Builder builder = new EntityProfile.Builder(); + if (profilesToCollect.contains(EntityProfileName.STATE)) { + builder.state(EntityState.RUNNING); + } + if (profilesToCollect.contains(EntityProfileName.INIT_PROGRESS)) { + InitProgressProfile initProgress = new InitProgressProfile("100%", 0, 0); + builder.initProgress(initProgress); + } + delegateListener.onResponse(builder.build()); + } + + private void sendInitState( + Set profilesToCollect, + Entity entityValue, + AnomalyDetector detector, + long updates, + MultiResponsesDelegateActionListener delegateListener + ) { + EntityProfile.Builder builder = new EntityProfile.Builder(); + if (profilesToCollect.contains(EntityProfileName.STATE)) { + builder.state(EntityState.INIT); + } + if (profilesToCollect.contains(EntityProfileName.INIT_PROGRESS)) { + long intervalMins = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().toMinutes(); + InitProgressProfile initProgress = computeInitProgressProfile(updates, intervalMins); + builder.initProgress(initProgress); + } + delegateListener.onResponse(builder.build()); + } + + private SearchRequest createLastSampleTimeRequest(String detectorId, long enabledTime, Entity entity, String resultIndex) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + + String path = "entity"; + String entityName = path + ".name"; + String entityValue = path + ".value"; + + for (Map.Entry attribute : entity.getAttributes().entrySet()) { + /* + * each attribute pair corresponds to a nested query like + "nested": { + "query": { + "bool": { + "filter": [ + { + "term": { + "entity.name": { + "value": "turkey4", + "boost": 1 + } + } + }, + { + "term": { + "entity.value": { + "value": "Turkey", + "boost": 1 + } + } + } + ] + } + }, + "path": "entity", + "ignore_unmapped": false, + "score_mode": "none", + "boost": 1 + } + },*/ + BoolQueryBuilder nestedBoolQueryBuilder = new BoolQueryBuilder(); + + TermQueryBuilder entityNameFilterQuery = QueryBuilders.termQuery(entityName, attribute.getKey()); + nestedBoolQueryBuilder.filter(entityNameFilterQuery); + TermQueryBuilder entityValueFilterQuery = QueryBuilders.termQuery(entityValue, attribute.getValue()); + nestedBoolQueryBuilder.filter(entityValueFilterQuery); + + NestedQueryBuilder nestedNameQueryBuilder = new NestedQueryBuilder(path, nestedBoolQueryBuilder, ScoreMode.None); + boolQueryBuilder.filter(nestedNameQueryBuilder); + } + + boolQueryBuilder.filter(QueryBuilders.termQuery(AnomalyResult.DETECTOR_ID_FIELD, detectorId)); + + boolQueryBuilder.filter(QueryBuilders.rangeQuery(CommonName.EXECUTION_END_TIME_FIELD).gte(enabledTime)); + + SearchSourceBuilder source = new SearchSourceBuilder() + .query(boolQueryBuilder) + .aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(CommonName.EXECUTION_END_TIME_FIELD)) + .trackTotalHits(false) + .size(0); + + SearchRequest request = new SearchRequest(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); + request.source(source); + if (resultIndex != null) { + 
request.indices(resultIndex); + } + return request; + } +} diff --git a/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java b/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java index 3d0f58ac7..4b05295ae 100644 --- a/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java +++ b/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java @@ -45,6 +45,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.search.SearchHits; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TimeSeriesException; @@ -164,7 +165,8 @@ private void updateRealtimeTask(AnomalyResultResponse response, String detectorI // real time init progress is 0 may mean this is a newly started detector // Delay real time cache update by one minute. If we are in init status, the delay may give the model training time to // finish. We can change the detector running immediately instead of waiting for the next interval. - threadPool.schedule(profileHCInitProgress, new TimeValue(60, TimeUnit.SECONDS), AnomalyDetectorPlugin.AD_THREAD_POOL_NAME); + threadPool + .schedule(profileHCInitProgress, new TimeValue(60, TimeUnit.SECONDS), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); } else { profileHCInitProgress.run(); } @@ -317,7 +319,7 @@ public void indexAnomalyResultException( log.error("Fail to execute RCFRollingAction", e); updateLatestRealtimeTask(detectorId, taskState, null, null, errorMessage); })); - }, new TimeValue(60, TimeUnit.SECONDS), AnomalyDetectorPlugin.AD_THREAD_POOL_NAME); + }, new TimeValue(60, TimeUnit.SECONDS), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); } else { updateLatestRealtimeTask(detectorId, taskState, null, null, errorMessage); } diff --git a/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java-e b/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java-e new file mode 100644 index 000000000..4b05295ae --- /dev/null +++ b/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java-e @@ -0,0 +1,381 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad; + +import static org.opensearch.ad.constant.ADCommonMessages.CAN_NOT_FIND_LATEST_TASK; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.model.DetectorProfileName; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.AnomalyResultResponse; +import org.opensearch.ad.transport.ProfileAction; +import org.opensearch.ad.transport.ProfileRequest; +import org.opensearch.ad.transport.RCFPollingAction; +import org.opensearch.ad.transport.RCFPollingRequest; +import org.opensearch.ad.transport.handler.AnomalyIndexHandler; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.search.SearchHits; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +public class ExecuteADResultResponseRecorder { + private static final Logger log = LogManager.getLogger(ExecuteADResultResponseRecorder.class); + + private ADIndexManagement anomalyDetectionIndices; + private AnomalyIndexHandler anomalyResultHandler; + private ADTaskManager adTaskManager; + private DiscoveryNodeFilterer nodeFilter; + private ThreadPool threadPool; + private Client client; + private NodeStateManager nodeStateManager; + private ADTaskCacheManager adTaskCacheManager; + private int rcfMinSamples; + + public ExecuteADResultResponseRecorder( + ADIndexManagement anomalyDetectionIndices, + AnomalyIndexHandler anomalyResultHandler, + ADTaskManager adTaskManager, + DiscoveryNodeFilterer nodeFilter, + ThreadPool threadPool, + Client client, + NodeStateManager nodeStateManager, + ADTaskCacheManager adTaskCacheManager, + int rcfMinSamples + ) { + this.anomalyDetectionIndices = anomalyDetectionIndices; + this.anomalyResultHandler = anomalyResultHandler; + this.adTaskManager = adTaskManager; + this.nodeFilter = nodeFilter; + this.threadPool = threadPool; + this.client = client; + this.nodeStateManager = nodeStateManager; + this.adTaskCacheManager = adTaskCacheManager; + this.rcfMinSamples = rcfMinSamples; + } + + public void indexAnomalyResult( + Instant detectionStartTime, + Instant executionStartTime, + AnomalyResultResponse response, + AnomalyDetector detector + ) { + String detectorId = detector.getId(); + try { + // skipping writing to the result index if not necessary + // For a single-entity detector, the result is not useful 
if error is null + // and rcf score (thus anomaly grade/confidence) is null. + // For a HCAD detector, we don't need to save on the detector level. + // We return 0 or Double.NaN rcf score if there is no error. + if ((response.getAnomalyScore() <= 0 || Double.isNaN(response.getAnomalyScore())) && response.getError() == null) { + updateRealtimeTask(response, detectorId); + return; + } + IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) detector.getWindowDelay(); + Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); + Instant dataEndTime = executionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); + User user = detector.getUser(); + + if (response.getError() != null) { + log.info("Anomaly result action run successfully for {} with error {}", detectorId, response.getError()); + } + + AnomalyResult anomalyResult = response + .toAnomalyResult( + detectorId, + dataStartTime, + dataEndTime, + executionStartTime, + Instant.now(), + anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT), + user, + response.getError() + ); + + String resultIndex = detector.getCustomResultIndex(); + anomalyResultHandler.index(anomalyResult, detectorId, resultIndex); + updateRealtimeTask(response, detectorId); + } catch (EndRunException e) { + throw e; + } catch (Exception e) { + log.error("Failed to index anomaly result for " + detectorId, e); + } + } + + /** + * Update real time task (one document per detector in state index). If the real-time task has no changes compared with local cache, + * the task won't update. Task only updates when the state changed, or any error happened, or AD job stopped. Task is mainly consumed + * by the front-end to track detector status. For single-stream detectors, we embed model total updates in AnomalyResultResponse and + * update state accordingly. For HCAD, we won't wait for model finishing updating before returning a response to the job scheduler + * since it might be long before all entities finish execution. So we don't embed model total updates in AnomalyResultResponse. + * Instead, we issue a profile request to poll each model node and get the maximum total updates among all models. + * @param response response returned from executing AnomalyResultAction + * @param detectorId Detector Id + */ + private void updateRealtimeTask(AnomalyResultResponse response, String detectorId) { + if (response.isHCDetector() != null && response.isHCDetector()) { + if (adTaskManager.skipUpdateHCRealtimeTask(detectorId, response.getError())) { + return; + } + DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); + Set profiles = new HashSet<>(); + profiles.add(DetectorProfileName.INIT_PROGRESS); + ProfileRequest profileRequest = new ProfileRequest(detectorId, profiles, true, dataNodes); + Runnable profileHCInitProgress = () -> { + client.execute(ProfileAction.INSTANCE, profileRequest, ActionListener.wrap(r -> { + log.debug("Update latest realtime task for HC detector {}, total updates: {}", detectorId, r.getTotalUpdates()); + updateLatestRealtimeTask(detectorId, null, r.getTotalUpdates(), response.getIntervalInMinutes(), response.getError()); + }, e -> { log.error("Failed to update latest realtime task for " + detectorId, e); })); + }; + if (!adTaskManager.isHCRealtimeTaskStartInitializing(detectorId)) { + // real time init progress is 0 may mean this is a newly started detector + // Delay real time cache update by one minute. 
If we are in init status, the delay may give the model training time to + // finish. We can change the detector running immediately instead of waiting for the next interval. + threadPool + .schedule(profileHCInitProgress, new TimeValue(60, TimeUnit.SECONDS), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); + } else { + profileHCInitProgress.run(); + } + + } else { + log + .debug( + "Update latest realtime task for single stream detector {}, total updates: {}", + detectorId, + response.getRcfTotalUpdates() + ); + updateLatestRealtimeTask(detectorId, null, response.getRcfTotalUpdates(), response.getIntervalInMinutes(), response.getError()); + } + } + + private void updateLatestRealtimeTask( + String detectorId, + String taskState, + Long rcfTotalUpdates, + Long detectorIntervalInMinutes, + String error + ) { + // Don't need info as this will be printed repeatedly in each interval + ActionListener listener = ActionListener.wrap(r -> { + if (r != null) { + log.debug("Updated latest realtime task successfully for detector {}, taskState: {}", detectorId, taskState); + } + }, e -> { + if ((e instanceof ResourceNotFoundException) && e.getMessage().contains(CAN_NOT_FIND_LATEST_TASK)) { + // Clear realtime task cache, will recreate AD task in next run, check AnomalyResultTransportAction. + log.error("Can't find latest realtime task of detector " + detectorId); + adTaskManager.removeRealtimeTaskCache(detectorId); + } else { + log.error("Failed to update latest realtime task for detector " + detectorId, e); + } + }); + + // rcfTotalUpdates is null when we save exception messages + if (!adTaskCacheManager.hasQueriedResultIndex(detectorId) && rcfTotalUpdates != null && rcfTotalUpdates < rcfMinSamples) { + // confirm the total updates number since it is possible that we have already had results after job enabling time + // If yes, total updates should be at least rcfMinSamples so that the init progress reaches 100%. + confirmTotalRCFUpdatesFound( + detectorId, + taskState, + rcfTotalUpdates, + detectorIntervalInMinutes, + error, + ActionListener + .wrap( + r -> adTaskManager + .updateLatestRealtimeTaskOnCoordinatingNode( + detectorId, + taskState, + r, + detectorIntervalInMinutes, + error, + listener + ), + e -> { + log.error("Fail to confirm rcf update", e); + adTaskManager + .updateLatestRealtimeTaskOnCoordinatingNode( + detectorId, + taskState, + rcfTotalUpdates, + detectorIntervalInMinutes, + error, + listener + ); + } + ) + ); + } else { + adTaskManager + .updateLatestRealtimeTaskOnCoordinatingNode( + detectorId, + taskState, + rcfTotalUpdates, + detectorIntervalInMinutes, + error, + listener + ); + } + } + + /** + * The function is not only indexing the result with the exception, but also updating the task state after + * 60s if the exception is related to cold start (index not found exceptions) for a single stream detector. 
+ * + * @param detectionStartTime execution start time + * @param executionStartTime execution end time + * @param errorMessage Error message to record + * @param taskState AD task state (e.g., stopped) + * @param detector Detector config accessor + */ + public void indexAnomalyResultException( + Instant detectionStartTime, + Instant executionStartTime, + String errorMessage, + String taskState, + AnomalyDetector detector + ) { + String detectorId = detector.getId(); + try { + IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) detector.getWindowDelay(); + Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); + Instant dataEndTime = executionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); + User user = detector.getUser(); + + AnomalyResult anomalyResult = new AnomalyResult( + detectorId, + null, // no task id + new ArrayList(), + dataStartTime, + dataEndTime, + executionStartTime, + Instant.now(), + errorMessage, + Optional.empty(), // single-stream detectors have no entity + user, + anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT), + null // no model id + ); + String resultIndex = detector.getCustomResultIndex(); + if (resultIndex != null && !anomalyDetectionIndices.doesIndexExist(resultIndex)) { + // Set result index as null, will write exception to default result index. + anomalyResultHandler.index(anomalyResult, detectorId, null); + } else { + anomalyResultHandler.index(anomalyResult, detectorId, resultIndex); + } + + if (errorMessage.contains(ADCommonMessages.NO_MODEL_ERR_MSG) && !detector.isHighCardinality()) { + // single stream detector raises ResourceNotFoundException containing CommonErrorMessages.NO_CHECKPOINT_ERR_MSG + // when there is no checkpoint. + // Delay real time cache update by one minute so we will have trained models by then and update the state + // document accordingly. + threadPool.schedule(() -> { + RCFPollingRequest request = new RCFPollingRequest(detectorId); + client.execute(RCFPollingAction.INSTANCE, request, ActionListener.wrap(rcfPollResponse -> { + long totalUpdates = rcfPollResponse.getTotalUpdates(); + // if there are updates, don't record failures + updateLatestRealtimeTask( + detectorId, + taskState, + totalUpdates, + detector.getIntervalInMinutes(), + totalUpdates > 0 ? 
"" : errorMessage + ); + }, e -> { + log.error("Fail to execute RCFRollingAction", e); + updateLatestRealtimeTask(detectorId, taskState, null, null, errorMessage); + })); + }, new TimeValue(60, TimeUnit.SECONDS), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); + } else { + updateLatestRealtimeTask(detectorId, taskState, null, null, errorMessage); + } + + } catch (Exception e) { + log.error("Failed to index anomaly result for " + detectorId, e); + } + } + + private void confirmTotalRCFUpdatesFound( + String detectorId, + String taskState, + Long rcfTotalUpdates, + Long detectorIntervalInMinutes, + String error, + ActionListener listener + ) { + nodeStateManager.getAnomalyDetector(detectorId, ActionListener.wrap(detectorOptional -> { + if (!detectorOptional.isPresent()) { + listener.onFailure(new TimeSeriesException(detectorId, "fail to get detector")); + return; + } + nodeStateManager.getAnomalyDetectorJob(detectorId, ActionListener.wrap(jobOptional -> { + if (!jobOptional.isPresent()) { + listener.onFailure(new TimeSeriesException(detectorId, "fail to get job")); + return; + } + + ProfileUtil + .confirmDetectorRealtimeInitStatus( + detectorOptional.get(), + jobOptional.get().getEnabledTime().toEpochMilli(), + client, + ActionListener.wrap(searchResponse -> { + ActionListener.completeWith(listener, () -> { + SearchHits hits = searchResponse.getHits(); + Long correctedTotalUpdates = rcfTotalUpdates; + if (hits.getTotalHits().value > 0L) { + // correct the number if we have already had results after job enabling time + // so that the detector won't stay initialized + correctedTotalUpdates = Long.valueOf(rcfMinSamples); + } + adTaskCacheManager.markResultIndexQueried(detectorId); + return correctedTotalUpdates; + }); + }, exception -> { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + // anomaly result index is not created yet + adTaskCacheManager.markResultIndexQueried(detectorId); + listener.onResponse(0L); + } else { + listener.onFailure(exception); + } + }) + ); + }, e -> listener.onFailure(new TimeSeriesException(detectorId, "fail to get job")))); + }, e -> listener.onFailure(new TimeSeriesException(detectorId, "fail to get detector")))); + } +} diff --git a/src/main/java/org/opensearch/ad/ExpiringState.java-e b/src/main/java/org/opensearch/ad/ExpiringState.java-e new file mode 100644 index 000000000..0df0e1f51 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ExpiringState.java-e @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad; + +import java.time.Duration; +import java.time.Instant; + +/** + * Represent a state that can be expired with a duration if not accessed + * + */ +public interface ExpiringState { + default boolean expired(Instant lastAccessTime, Duration stateTtl, Instant now) { + return lastAccessTime.plus(stateTtl).isBefore(now); + } + + boolean expired(Duration stateTtl); +} diff --git a/src/main/java/org/opensearch/ad/MaintenanceState.java-e b/src/main/java/org/opensearch/ad/MaintenanceState.java-e new file mode 100644 index 000000000..646715f7a --- /dev/null +++ b/src/main/java/org/opensearch/ad/MaintenanceState.java-e @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import java.time.Duration; +import java.util.Map; + +/** + * Represent a state that needs to maintain its metadata regularly + * + * + */ +public interface MaintenanceState { + default void maintenance(Map stateToClean, Duration stateTtl) { + stateToClean.entrySet().stream().forEach(entry -> { + K detectorId = entry.getKey(); + + V state = entry.getValue(); + if (state.expired(stateTtl)) { + stateToClean.remove(detectorId); + } + + }); + } + + void maintenance(); +} diff --git a/src/main/java/org/opensearch/ad/MemoryTracker.java-e b/src/main/java/org/opensearch/ad/MemoryTracker.java-e new file mode 100644 index 000000000..1e40ef47a --- /dev/null +++ b/src/main/java/org/opensearch/ad/MemoryTracker.java-e @@ -0,0 +1,363 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE; + +import java.util.EnumMap; +import java.util.Locale; +import java.util.Map; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.monitor.jvm.JvmService; +import org.opensearch.timeseries.common.exception.LimitExceededException; + +import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * Class to track AD memory usage. + * + */ +public class MemoryTracker { + private static final Logger LOG = LogManager.getLogger(MemoryTracker.class); + + public enum Origin { + SINGLE_ENTITY_DETECTOR, + HC_DETECTOR, + HISTORICAL_SINGLE_ENTITY_DETECTOR, + } + + // memory tracker for total consumption of bytes + private long totalMemoryBytes; + private final Map totalMemoryBytesByOrigin; + // reserved for models. Cannot be deleted at will. 
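/*
 * A minimal sketch of how ExpiringState and MaintenanceState above compose, using hypothetical
 * class names (CacheEntry, EntryStore) and assuming the default maintenance(Map, Duration) helper
 * accepts any map whose values implement ExpiringState: an entry refreshes its last access time on
 * use, and a periodic maintenance() pass evicts entries whose TTL has elapsed.
 */
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

class CacheEntry implements ExpiringState {
    private final Clock clock;
    private volatile Instant lastAccessTime;

    CacheEntry(Clock clock) {
        this.clock = clock;
        this.lastAccessTime = clock.instant();
    }

    void touch() {
        // refresh on every read/write so active entries are never evicted
        lastAccessTime = clock.instant();
    }

    @Override
    public boolean expired(Duration stateTtl) {
        // delegate to the default helper: lastAccessTime + stateTtl < now
        return expired(lastAccessTime, stateTtl, clock.instant());
    }
}

class EntryStore implements MaintenanceState {
    private final Map<String, CacheEntry> entries = new ConcurrentHashMap<>();
    private final Duration stateTtl = Duration.ofHours(1);   // hypothetical TTL

    @Override
    public void maintenance() {
        // the default implementation walks the map and removes entries that report expired(stateTtl)
        maintenance(entries, stateTtl);
    }
}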
+ private long reservedMemoryBytes; + private final Map reservedMemoryBytesByOrigin; + private long heapSize; + private long heapLimitBytes; + private long desiredModelSize; + // we observe threshold model uses a fixed size array and the size is the same + private int thresholdModelBytes; + private ADCircuitBreakerService adCircuitBreakerService; + + /** + * Constructor + * + * @param jvmService Service providing jvm info + * @param modelMaxSizePercentage Percentage of heap for the max size of a model + * @param modelDesiredSizePercentage percentage of heap for the desired size of a model + * @param clusterService Cluster service object + * @param adCircuitBreakerService Memory circuit breaker + */ + public MemoryTracker( + JvmService jvmService, + double modelMaxSizePercentage, + double modelDesiredSizePercentage, + ClusterService clusterService, + ADCircuitBreakerService adCircuitBreakerService + ) { + this.totalMemoryBytes = 0; + this.totalMemoryBytesByOrigin = new EnumMap(Origin.class); + this.reservedMemoryBytes = 0; + this.reservedMemoryBytesByOrigin = new EnumMap(Origin.class); + this.heapSize = jvmService.info().getMem().getHeapMax().getBytes(); + this.heapLimitBytes = (long) (heapSize * modelMaxSizePercentage); + this.desiredModelSize = (long) (heapSize * modelDesiredSizePercentage); + if (clusterService != null) { + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(MODEL_MAX_SIZE_PERCENTAGE, it -> this.heapLimitBytes = (long) (heapSize * it)); + } + + this.thresholdModelBytes = 180_000; + this.adCircuitBreakerService = adCircuitBreakerService; + } + + /** + * This function derives from the old code: https://tinyurl.com/2eaabja6 + * + * @param detectorId Detector Id + * @param trcf Thresholded random cut forest model + * @return true if there is enough memory; otherwise throw LimitExceededException. + */ + public synchronized boolean isHostingAllowed(String detectorId, ThresholdedRandomCutForest trcf) { + long requiredBytes = estimateTRCFModelSize(trcf); + if (canAllocateReserved(requiredBytes)) { + return true; + } else { + throw new LimitExceededException( + detectorId, + String + .format( + Locale.ROOT, + "Exceeded memory limit. New size is %d bytes and max limit is %d bytes", + reservedMemoryBytes + requiredBytes, + heapLimitBytes + ) + ); + } + } + + /** + * @param requiredBytes required bytes to allocate + * @return whether there is enough memory for the required bytes. This is + * true when circuit breaker is closed and there is enough reserved memory. + */ + public synchronized boolean canAllocateReserved(long requiredBytes) { + return (false == adCircuitBreakerService.isOpen() && reservedMemoryBytes + requiredBytes <= heapLimitBytes); + } + + /** + * @param bytes required bytes + * @return whether there is enough memory for the required bytes. This is + * true when circuit breaker is closed and there is enough overall memory. 
+ */ + public synchronized boolean canAllocate(long bytes) { + return false == adCircuitBreakerService.isOpen() && totalMemoryBytes + bytes <= heapLimitBytes; + } + + public synchronized void consumeMemory(long memoryToConsume, boolean reserved, Origin origin) { + totalMemoryBytes += memoryToConsume; + adjustOriginMemoryConsumption(memoryToConsume, origin, totalMemoryBytesByOrigin); + if (reserved) { + reservedMemoryBytes += memoryToConsume; + adjustOriginMemoryConsumption(memoryToConsume, origin, reservedMemoryBytesByOrigin); + } + } + + private void adjustOriginMemoryConsumption(long memoryToConsume, Origin origin, Map mapToUpdate) { + Long originTotalMemoryBytes = mapToUpdate.getOrDefault(origin, 0L); + mapToUpdate.put(origin, originTotalMemoryBytes + memoryToConsume); + } + + public synchronized void releaseMemory(long memoryToShed, boolean reserved, Origin origin) { + totalMemoryBytes -= memoryToShed; + adjustOriginMemoryRelease(memoryToShed, origin, totalMemoryBytesByOrigin); + if (reserved) { + reservedMemoryBytes -= memoryToShed; + adjustOriginMemoryRelease(memoryToShed, origin, reservedMemoryBytesByOrigin); + } + } + + private void adjustOriginMemoryRelease(long memoryToConsume, Origin origin, Map mapToUpdate) { + Long originTotalMemoryBytes = mapToUpdate.get(origin); + if (originTotalMemoryBytes != null) { + mapToUpdate.put(origin, originTotalMemoryBytes - memoryToConsume); + } + } + + /** + * Gets the estimated size of an entity's model. + * + * @param trcf ThresholdedRandomCutForest object + * @return estimated model size in bytes + */ + public long estimateTRCFModelSize(ThresholdedRandomCutForest trcf) { + RandomCutForest forest = trcf.getForest(); + return estimateTRCFModelSize( + forest.getDimensions(), + forest.getNumberOfTrees(), + forest.getBoundingBoxCacheFraction(), + forest.getShingleSize(), + forest.isInternalShinglingEnabled() + ); + } + + /** + * Gets the estimated size of an entity's model. + * + * RCF size: + * Assume the sample size is 256. I measured the memory size of a ThresholdedRandomCutForest + * using heap dump. A ThresholdedRandomCutForest comprises a compact rcf model and + * a threshold model. + * + * A compact RCF forest consists of: + * - Random number generator: 56 bytes + * - PointStoreCoordinator: 24 bytes + * - SequentialForestUpdateExecutor: 24 bytes + * - SequentialForestTraversalExecutor: 16 bytes + * - PointStoreFloat + * + IndexManager + * - int array for free indexes: 256 * numberOfTrees * 4, where 4 is the size of an integer + * - two int array for locationList and refCount: 256 * numberOfTrees * 4 bytes * 2 + * - a float array for data store: 256 * trees * dimension * 4 bytes: due to various + * optimization like shingleSize(dimensions), we don't use all of the array. The average + * usage percentage depends on shingle size and if internal shingling is enabled. + * I did experiments with power-of-two shingle sizes and internal shingling on/off + * by running ThresholdedRandomCutForest over a million points. + * My experiment shows that + * * if internal shingling is off, data store is filled at full + * capacity. + * * otherwise, data store usage depends on shingle size: + * + * Shingle Size usage + * 1 1 + * 2 0.53 + * 4 0.27 + * 8 0.27 + * 16 0.13 + * 32 0.07 + * 64 0.07 + * + * The formula reflects the data and fits the point store usage to the closest power-of-two case. + * For example, if shingle size is 17, we use the usage 0.13 since it is closer to 16. 
+ * + * {@code IF(dimensions>=32, 1/(LOG(dimensions+1, 2)+LOG(dimensions+1, 10)), 1/LOG(dimensions+1, 2))} + * where LOG gets the logarithm of a number and the syntax of LOG is {@code LOG (number, [base])}. + * We derive the formula by observing the point store usage ratio is a decreasing function of dimensions + * and the relationship is logarithm. Adding 1 to dimension to ensure dimension 1 results in a ratio 1. + * - ComponentList: an array of size numberOfTrees + * + SamplerPlusTree + * - CompactSampler: 2248 + * + CompactRandomCutTreeFloat + * - other fields: 152 + * - SmallNodeStore (small node store since our sample size is 256, less than the max of short): 6120 + * + BoxCacheFloat + * - other: 104 + * - BoundingBoxFloat: (1040 + 255* ((dimension * 4 + 16) * 2 + 32)) * actual bounding box cache usage, + * {@code actual bounding box cache usage = (bounding box cache fraction >= 0.3? 1: bounding box cache fraction)} + * {@code >= 0.3} we will still initialize bounding box cache array of the max size. + * 1040 is the size of BoundingBoxFloat's fields unrelated to tree size (255 nodes in our formula) + * In total, RCF size is + * 56 + # trees * (2248 + 152 + 6120 + 104 + (1040 + 255* (dimension * 4 + 16) * 2 + 32)) * adjusted bounding box cache ratio) + + * (256 * # trees * 2 + 256 * # trees * dimension) * 4 bytes * point store ratio + 30744 * 2 + 15432 + 208) + 24 + 24 + 16 + * = 56 + # trees * (8624 + (1040 + 255 * (dimension * 8 + 64)) * actual bounding box cache usage) + 256 * # trees * + * dimension * 4 * point store ratio + 77192 + * + * Thresholder size + * + Preprocessor: + * - lastShingledInput and lastShingledPoint: 2*(dimension*8 + 16) (2 due to 2 double arrays, 16 are array object size) + * - previousTimeStamps: shingle*8 + * - other: 248 + * - BasicThrehsolder: 256 + * + lastAnomalyAttribution: + * - high and low: 2*(dimension*8 + 16)(2 due to 2 double arrays, 16 are array object) + * - other 24 + * - lastAnomalyPoint and lastExpectedPoint: 2*(dimension*8 + 16) + * - other like ThresholdedRandomCutForest object size: 96 + * In total, thresholder size is: + * 6*(dimension*8 + 16) + shingle*8 + 248 + 256 + 24 + 96 + * = 6*(dimension*8 + 16) + shingle*8 + 624 + * + * @param dimension The number of feature dimensions in RCF + * @param numberOfTrees The number of trees in RCF + * @param boundingBoxCacheFraction Bounding box cache usage in RCF + * @param shingleSize shingle size + * @param internalShingling whether internal shingling is enabled or not + * @return estimated TRCF model size + * + * @throws IllegalArgumentException when the input shingle size is out of range [1, 64] + */ + public long estimateTRCFModelSize( + int dimension, + int numberOfTrees, + double boundingBoxCacheFraction, + int shingleSize, + boolean internalShingling + ) { + double averagePointStoreUsage = 0; + if (!internalShingling || shingleSize == 1) { + averagePointStoreUsage = 1; + } else if (shingleSize <= 3) { + averagePointStoreUsage = 0.53; + } else if (shingleSize <= 12) { + averagePointStoreUsage = 0.27; + } else if (shingleSize <= 24) { + averagePointStoreUsage = 0.13; + } else if (shingleSize <= 64) { + averagePointStoreUsage = 0.07; + } else { + throw new IllegalArgumentException("out of range shingle size " + shingleSize); + } + + double actualBoundingBoxUsage = boundingBoxCacheFraction >= 0.3 ? 
1d : boundingBoxCacheFraction; + long compactRcfSize = (long) (56 + numberOfTrees * (8624 + (1040 + 255 * (dimension * 8 + 64)) * actualBoundingBoxUsage) + 256 + * numberOfTrees * dimension * 4 * averagePointStoreUsage + 77192); + long thresholdSize = 6 * (dimension * 8 + 16) + shingleSize * 8 + 624; + return compactRcfSize + thresholdSize; + } + + /** + * Bytes to remove to keep AD memory usage within the limit + * @return bytes to remove + */ + public synchronized long memoryToShed() { + return totalMemoryBytes - heapLimitBytes; + } + + /** + * + * @return Allowed heap usage in bytes by AD models + */ + public long getHeapLimit() { + return heapLimitBytes; + } + + /** + * + * @return Desired model partition size in bytes + */ + public long getDesiredModelSize() { + return desiredModelSize; + } + + public long getTotalMemoryBytes() { + return totalMemoryBytes; + } + + /** + * In case of bugs/race conditions or users dyanmically changing dedicated/shared + * cache size, sync used bytes infrequently by recomputing memory usage. + * @param origin Origin + * @param totalBytes total bytes from recomputing + * @param reservedBytes reserved bytes from recomputing + * @return whether memory adjusted due to mismatch + */ + public synchronized boolean syncMemoryState(Origin origin, long totalBytes, long reservedBytes) { + long recordedTotalBytes = totalMemoryBytesByOrigin.getOrDefault(origin, 0L); + long recordedReservedBytes = reservedMemoryBytesByOrigin.getOrDefault(origin, 0L); + if (totalBytes == recordedTotalBytes && reservedBytes == recordedReservedBytes) { + return false; + } + + LOG + .info( + String + .format( + Locale.ROOT, + "Memory states do not match. Recorded: total bytes %d, reserved bytes %d." + + "Actual: total bytes %d, reserved bytes: %d", + recordedTotalBytes, + recordedReservedBytes, + totalBytes, + reservedBytes + ) + ); + // reserved bytes mismatch + long reservedDiff = reservedBytes - recordedReservedBytes; + reservedMemoryBytesByOrigin.put(origin, reservedBytes); + reservedMemoryBytes += reservedDiff; + + long totalDiff = totalBytes - recordedTotalBytes; + totalMemoryBytesByOrigin.put(origin, totalBytes); + totalMemoryBytes += totalDiff; + return true; + } + + public int getThresholdModelBytes() { + return thresholdModelBytes; + } +} diff --git a/src/main/java/org/opensearch/ad/NodeState.java-e b/src/main/java/org/opensearch/ad/NodeState.java-e new file mode 100644 index 000000000..9c4693cbd --- /dev/null +++ b/src/main/java/org/opensearch/ad/NodeState.java-e @@ -0,0 +1,206 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Optional; + +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; + +/** + * Storing intermediate state during the execution of transport action + * + */ +public class NodeState implements ExpiringState { + private String detectorId; + // detector definition + private AnomalyDetector detectorDef; + // number of partitions + private int partitonNumber; + // last access time + private Instant lastAccessTime; + // last detection error recorded in result index. 
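/*
 * A worked example of estimateTRCFModelSize(...) above, assuming hypothetical inputs: 30 trees,
 * one feature with shingle size 8 (so dimension = 8), bounding box cache fraction 1.0, and
 * internal shingling on (point store usage 0.27 for shingle sizes 4-12):
 *
 *   compactRcfSize = 56 + 30 * (8624 + (1040 + 255 * (8 * 8 + 64)) * 1.0)
 *                    + 256 * 30 * 8 * 4 * 0.27 + 77192          ~ 1,412,723 bytes
 *   thresholdSize  = 6 * (8 * 8 + 16) + 8 * 8 + 624              = 1,168 bytes
 *   total          ~ 1,413,891 bytes, i.e. roughly 1.4 MB per model
 *
 * A typical guard-then-consume cycle against the tracker, again as a sketch:
 */
static void hostModel(MemoryTracker memoryTracker) {
    // hypothetical model shape: dimension 8, 30 trees, full bounding box cache, shingle size 8
    long bytes = memoryTracker.estimateTRCFModelSize(8, 30, 1.0, 8, true);
    if (memoryTracker.canAllocateReserved(bytes)) {
        // long-lived model, so charge it to the reserved budget under the HC detector origin
        memoryTracker.consumeMemory(bytes, true, MemoryTracker.Origin.HC_DETECTOR);
        // ... host the model; when it is later evicted, give the same amount back
        memoryTracker.releaseMemory(bytes, true, MemoryTracker.Origin.HC_DETECTOR);
    }
}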
Used by DetectorStateHandler + // to check if the error for a detector has changed or not. If changed, trigger indexing. + private Optional lastDetectionError; + // last error. + private Optional exception; + // flag indicating whether checkpoint for the detector exists + private boolean checkPointExists; + // clock to get current time + private final Clock clock; + // cold start running flag to prevent concurrent cold start + private boolean coldStartRunning; + // detector job + private AnomalyDetectorJob detectorJob; + + public NodeState(String detectorId, Clock clock) { + this.detectorId = detectorId; + this.detectorDef = null; + this.partitonNumber = -1; + this.lastAccessTime = clock.instant(); + this.lastDetectionError = Optional.empty(); + this.exception = Optional.empty(); + this.checkPointExists = false; + this.clock = clock; + this.coldStartRunning = false; + this.detectorJob = null; + } + + public String getId() { + return detectorId; + } + + /** + * + * @return Detector configuration object + */ + public AnomalyDetector getDetectorDef() { + refreshLastUpdateTime(); + return detectorDef; + } + + /** + * + * @param detectorDef Detector configuration object + */ + public void setDetectorDef(AnomalyDetector detectorDef) { + this.detectorDef = detectorDef; + refreshLastUpdateTime(); + } + + /** + * + * @return RCF partition number of the detector + */ + public int getPartitonNumber() { + refreshLastUpdateTime(); + return partitonNumber; + } + + /** + * + * @param partitonNumber RCF partition number + */ + public void setPartitonNumber(int partitonNumber) { + this.partitonNumber = partitonNumber; + refreshLastUpdateTime(); + } + + /** + * Used to indicate whether cold start succeeds or not + * @return whether checkpoint of models exists or not. + */ + public boolean doesCheckpointExists() { + refreshLastUpdateTime(); + return checkPointExists; + } + + /** + * + * @param checkpointExists mark whether checkpoint of models exists or not. 
+ */ + public void setCheckpointExists(boolean checkpointExists) { + refreshLastUpdateTime(); + this.checkPointExists = checkpointExists; + }; + + /** + * + * @return last model inference error + */ + public Optional getLastDetectionError() { + refreshLastUpdateTime(); + return lastDetectionError; + } + + /** + * + * @param lastError last model inference error + */ + public void setLastDetectionError(String lastError) { + this.lastDetectionError = Optional.ofNullable(lastError); + refreshLastUpdateTime(); + } + + /** + * + * @return last exception if any + */ + public Optional getException() { + refreshLastUpdateTime(); + return exception; + } + + /** + * + * @param exception exception to record + */ + public void setException(Exception exception) { + this.exception = Optional.ofNullable(exception); + refreshLastUpdateTime(); + } + + /** + * Used to prevent concurrent cold start + * @return whether cold start is running or not + */ + public boolean isColdStartRunning() { + refreshLastUpdateTime(); + return coldStartRunning; + } + + /** + * + * @param coldStartRunning whether cold start is running or not + */ + public void setColdStartRunning(boolean coldStartRunning) { + this.coldStartRunning = coldStartRunning; + refreshLastUpdateTime(); + } + + /** + * + * @return Detector configuration object + */ + public AnomalyDetectorJob getDetectorJob() { + refreshLastUpdateTime(); + return detectorJob; + } + + /** + * + * @param detectorJob Detector job + */ + public void setDetectorJob(AnomalyDetectorJob detectorJob) { + this.detectorJob = detectorJob; + refreshLastUpdateTime(); + } + + /** + * refresh last access time. + */ + private void refreshLastUpdateTime() { + lastAccessTime = clock.instant(); + } + + /** + * @param stateTtl time to leave for the state + * @return whether the transport state is expired + */ + @Override + public boolean expired(Duration stateTtl) { + return expired(lastAccessTime, stateTtl, clock.instant()); + } +} diff --git a/src/main/java/org/opensearch/ad/NodeStateManager.java b/src/main/java/org/opensearch/ad/NodeStateManager.java index e99cfdbe8..7e3d708c2 100644 --- a/src/main/java/org/opensearch/ad/NodeStateManager.java +++ b/src/main/java/org/opensearch/ad/NodeStateManager.java @@ -13,7 +13,7 @@ import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.time.Clock; import java.time.Duration; diff --git a/src/main/java/org/opensearch/ad/NodeStateManager.java-e b/src/main/java/org/opensearch/ad/NodeStateManager.java-e new file mode 100644 index 000000000..7e3d708c2 --- /dev/null +++ b/src/main/java/org/opensearch/ad/NodeStateManager.java-e @@ -0,0 +1,408 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.time.Clock; +import java.time.Duration; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.logging.log4j.util.Strings; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.ml.SingleStreamModelIdMapper; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.transport.BackPressureRouting; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; + +/** + * NodeStateManager is used to manage states shared by transport and ml components + * like AnomalyDetector object + * + */ +public class NodeStateManager implements MaintenanceState, CleanState { + private static final Logger LOG = LogManager.getLogger(NodeStateManager.class); + public static final String NO_ERROR = "no_error"; + private ConcurrentHashMap states; + private Client client; + private NamedXContentRegistry xContentRegistry; + private ClientUtil clientUtil; + // map from detector id to the map of ES node id to the node's backpressureMuter + private Map> backpressureMuter; + private final Clock clock; + private final Duration stateTtl; + private int maxRetryForUnresponsiveNode; + private TimeValue mutePeriod; + + /** + * Constructor + * + * @param client Client to make calls to OpenSearch + * @param xContentRegistry ES named content registry + * @param settings ES settings + * @param clientUtil AD Client utility + * @param clock A UTC clock + * @param stateTtl Max time to keep state in memory + * @param clusterService Cluster service accessor + */ + public NodeStateManager( + Client client, + NamedXContentRegistry xContentRegistry, + Settings settings, + ClientUtil clientUtil, + Clock clock, + Duration stateTtl, + ClusterService clusterService + ) { + this.states = new ConcurrentHashMap<>(); + this.client = client; + this.xContentRegistry = xContentRegistry; + this.clientUtil = clientUtil; + this.backpressureMuter = new ConcurrentHashMap<>(); + this.clock = clock; + this.stateTtl = stateTtl; + this.maxRetryForUnresponsiveNode = MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(settings); + 
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_RETRY_FOR_UNRESPONSIVE_NODE, it -> { + this.maxRetryForUnresponsiveNode = it; + Iterator> iter = backpressureMuter.values().iterator(); + while (iter.hasNext()) { + Map entry = iter.next(); + entry.values().forEach(v -> v.setMaxRetryForUnresponsiveNode(it)); + } + }); + this.mutePeriod = BACKOFF_MINUTES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(BACKOFF_MINUTES, it -> { + this.mutePeriod = it; + Iterator> iter = backpressureMuter.values().iterator(); + while (iter.hasNext()) { + Map entry = iter.next(); + entry.values().forEach(v -> v.setMutePeriod(it)); + } + }); + } + + /** + * Get Detector config object if present + * @param adID detector Id + * @return the Detecor config object or empty Optional + */ + public Optional getAnomalyDetectorIfPresent(String adID) { + NodeState state = states.get(adID); + return Optional.ofNullable(state).map(NodeState::getDetectorDef); + } + + public void getAnomalyDetector(String adID, ActionListener> listener) { + NodeState state = states.get(adID); + if (state != null && state.getDetectorDef() != null) { + listener.onResponse(Optional.of(state.getDetectorDef())); + } else { + GetRequest request = new GetRequest(CommonName.CONFIG_INDEX, adID); + clientUtil.asyncRequest(request, client::get, onGetDetectorResponse(adID, listener)); + } + } + + private ActionListener onGetDetectorResponse(String adID, ActionListener> listener) { + return ActionListener.wrap(response -> { + if (response == null || !response.isExists()) { + listener.onResponse(Optional.empty()); + return; + } + + String xc = response.getSourceAsString(); + LOG.debug("Fetched anomaly detector: {}", xc); + + try ( + XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, xc) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId()); + // end execution if all features are disabled + if (detector.getEnabledFeatureIds().isEmpty()) { + listener.onFailure(new EndRunException(adID, CommonMessages.ALL_FEATURES_DISABLED_ERR_MSG, true).countedInStats(false)); + return; + } + NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock)); + state.setDetectorDef(detector); + + listener.onResponse(Optional.of(detector)); + } catch (Exception t) { + LOG.error("Fail to parse detector {}", adID); + LOG.error("Stack trace:", t); + listener.onResponse(Optional.empty()); + } + }, listener::onFailure); + } + + /** + * Get a detector's checkpoint and save a flag if we find any so that next time we don't need to do it again + * @param adID the detector's ID + * @param listener listener to handle get request + */ + public void getDetectorCheckpoint(String adID, ActionListener listener) { + NodeState state = states.get(adID); + if (state != null && state.doesCheckpointExists()) { + listener.onResponse(Boolean.TRUE); + return; + } + + GetRequest request = new GetRequest(ADCommonName.CHECKPOINT_INDEX_NAME, SingleStreamModelIdMapper.getRcfModelId(adID, 0)); + + clientUtil.asyncRequest(request, client::get, onGetCheckpointResponse(adID, listener)); + } + + private ActionListener onGetCheckpointResponse(String adID, ActionListener listener) { + return ActionListener.wrap(response -> { + if (response == null || !response.isExists()) { + listener.onResponse(Boolean.FALSE); + } else { + NodeState state = states.computeIfAbsent(adID, id -> 
new NodeState(id, clock)); + state.setCheckpointExists(true); + listener.onResponse(Boolean.TRUE); + } + }, listener::onFailure); + } + + /** + * Used in delete workflow + * + * @param detectorId detector ID + */ + @Override + public void clear(String detectorId) { + Map routingMap = backpressureMuter.get(detectorId); + if (routingMap != null) { + routingMap.clear(); + backpressureMuter.remove(detectorId); + } + states.remove(detectorId); + } + + /** + * Clean states if it is older than our stateTtl. transportState has to be a + * ConcurrentHashMap otherwise we will have + * java.util.ConcurrentModificationException. + * + */ + @Override + public void maintenance() { + maintenance(states, stateTtl); + } + + public boolean isMuted(String nodeId, String detectorId) { + Map routingMap = backpressureMuter.get(detectorId); + if (routingMap == null || routingMap.isEmpty()) { + return false; + } + BackPressureRouting routing = routingMap.get(nodeId); + return routing != null && routing.isMuted(); + } + + /** + * When we have a unsuccessful call with a node, increment the backpressure counter. + * @param nodeId an ES node's ID + * @param detectorId Detector ID + */ + public void addPressure(String nodeId, String detectorId) { + Map routingMap = backpressureMuter + .computeIfAbsent(detectorId, k -> new HashMap()); + routingMap.computeIfAbsent(nodeId, k -> new BackPressureRouting(k, clock, maxRetryForUnresponsiveNode, mutePeriod)).addPressure(); + } + + /** + * When we have a successful call with a node, clear the backpressure counter. + * @param nodeId an ES node's ID + * @param detectorId Detector ID + */ + public void resetBackpressureCounter(String nodeId, String detectorId) { + Map routingMap = backpressureMuter.get(detectorId); + if (routingMap == null || routingMap.isEmpty()) { + backpressureMuter.remove(detectorId); + return; + } + routingMap.remove(nodeId); + } + + /** + * Check if there is running query on given detector + * @param detector Anomaly Detector + * @return true if given detector has a running query else false + */ + public boolean hasRunningQuery(AnomalyDetector detector) { + return clientUtil.hasRunningQuery(detector); + } + + /** + * Get last error of a detector + * @param adID detector id + * @return last error for the detector + */ + public String getLastDetectionError(String adID) { + return Optional.ofNullable(states.get(adID)).flatMap(state -> state.getLastDetectionError()).orElse(NO_ERROR); + } + + /** + * Set last detection error of a detector + * @param adID detector id + * @param error error, can be null + */ + public void setLastDetectionError(String adID, String error) { + NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock)); + state.setLastDetectionError(error); + } + + /** + * Get a detector's exception. The method has side effect. + * We reset error after calling the method because + * 1) We record a detector's exception in each interval. There is no need + * to record it twice. + * 2) EndRunExceptions can stop job running. We only want to send the same + * signal once for each exception. + * @param adID detector id + * @return the detector's exception + */ + public Optional fetchExceptionAndClear(String adID) { + NodeState state = states.get(adID); + if (state == null) { + return Optional.empty(); + } + + Optional exception = state.getException(); + exception.ifPresent(e -> state.setException(null)); + return exception; + } + + /** + * For single-stream detector, we have one exception per interval. 
When + * an interval starts, it fetches and clears the exception. + * For HCAD, there can be one exception per entity. To not bloat memory + * with exceptions, we will keep only one exception. An exception has 3 purposes: + * 1) stop detector if nothing else works; + * 2) increment error stats to ticket about high-error domain + * 3) debugging. + * + * For HCAD, we record all entities' exceptions in anomaly results. So 3) + * is covered. As long as we keep one exception among all exceptions, 2) + * is covered. So the only thing we have to pay attention is to keep EndRunException. + * When overriding an exception, EndRunException has priority. + * @param detectorId Detector Id + * @param e Exception to set + */ + public void setException(String detectorId, Exception e) { + if (e == null || Strings.isEmpty(detectorId)) { + return; + } + NodeState state = states.computeIfAbsent(detectorId, d -> new NodeState(detectorId, clock)); + Optional exception = state.getException(); + if (exception.isPresent()) { + Exception higherPriorityException = ExceptionUtil.selectHigherPriorityException(e, exception.get()); + if (higherPriorityException != e) { + return; + } + } + + state.setException(e); + } + + /** + * Whether last cold start for the detector is running + * @param adID detector ID + * @return running or not + */ + public boolean isColdStartRunning(String adID) { + NodeState state = states.get(adID); + if (state != null) { + return state.isColdStartRunning(); + } + + return false; + } + + /** + * Mark the cold start status of the detector + * @param adID detector ID + * @return a callback when cold start is done + */ + public Releasable markColdStartRunning(String adID) { + NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock)); + state.setColdStartRunning(true); + return () -> { + NodeState nodeState = states.get(adID); + if (nodeState != null) { + nodeState.setColdStartRunning(false); + } + }; + } + + public void getAnomalyDetectorJob(String adID, ActionListener> listener) { + NodeState state = states.get(adID); + if (state != null && state.getDetectorJob() != null) { + listener.onResponse(Optional.of(state.getDetectorJob())); + } else { + GetRequest request = new GetRequest(CommonName.JOB_INDEX, adID); + clientUtil.asyncRequest(request, client::get, onGetDetectorJobResponse(adID, listener)); + } + } + + private ActionListener onGetDetectorJobResponse(String adID, ActionListener> listener) { + return ActionListener.wrap(response -> { + if (response == null || !response.isExists()) { + listener.onResponse(Optional.empty()); + return; + } + + String xc = response.getSourceAsString(); + LOG.debug("Fetched anomaly detector: {}", xc); + + try ( + XContentParser parser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, response.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser); + NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock)); + state.setDetectorJob(job); + + listener.onResponse(Optional.of(job)); + } catch (Exception t) { + LOG.error(new ParameterizedMessage("Fail to parse job {}", adID), t); + listener.onResponse(Optional.empty()); + } + }, listener::onFailure); + } +} diff --git a/src/main/java/org/opensearch/ad/ProfileUtil.java-e b/src/main/java/org/opensearch/ad/ProfileUtil.java-e new file mode 100644 index 000000000..8afd98dc3 --- /dev/null +++ 
b/src/main/java/org/opensearch/ad/ProfileUtil.java-e @@ -0,0 +1,65 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.client.Client; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.constant.CommonName; + +public class ProfileUtil { + /** + * Create search request to check if we have at least 1 anomaly score larger than 0 after AD job enabled time. + * Note this function is only meant to check for status of real time analysis. + * + * @param detectorId detector id + * @param enabledTime the time when AD job is enabled in milliseconds + * @return the search request + */ + private static SearchRequest createRealtimeInittedEverRequest(String detectorId, long enabledTime, String resultIndex) { + BoolQueryBuilder filterQuery = new BoolQueryBuilder(); + filterQuery.filter(QueryBuilders.termQuery(AnomalyResult.DETECTOR_ID_FIELD, detectorId)); + filterQuery.filter(QueryBuilders.rangeQuery(CommonName.EXECUTION_END_TIME_FIELD).gte(enabledTime)); + filterQuery.filter(QueryBuilders.rangeQuery(AnomalyResult.ANOMALY_SCORE_FIELD).gt(0)); + // Historical analysis result also stored in result index, which has non-null task_id. + // For realtime detection result, we should filter task_id == null + ExistsQueryBuilder taskIdExistsFilter = QueryBuilders.existsQuery(CommonName.TASK_ID_FIELD); + filterQuery.mustNot(taskIdExistsFilter); + + SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1); + + SearchRequest request = new SearchRequest(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); + request.source(source); + if (resultIndex != null) { + request.indices(resultIndex); + } + return request; + } + + public static void confirmDetectorRealtimeInitStatus( + AnomalyDetector detector, + long enabledTime, + Client client, + ActionListener listener + ) { + SearchRequest searchLatestResult = createRealtimeInittedEverRequest(detector.getId(), enabledTime, detector.getCustomResultIndex()); + client.search(searchLatestResult, listener); + } +} diff --git a/src/main/java/org/opensearch/ad/breaker/ADCircuitBreakerService.java-e b/src/main/java/org/opensearch/ad/breaker/ADCircuitBreakerService.java-e new file mode 100644 index 000000000..9c9ab5b34 --- /dev/null +++ b/src/main/java/org/opensearch/ad/breaker/ADCircuitBreakerService.java-e @@ -0,0 +1,91 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.breaker; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.monitor.jvm.JvmService; + +/** + * Class {@code ADCircuitBreakerService} provide storing, retrieving circuit breakers functions. + * + * This service registers internal system breakers and provide API for users to register their own breakers. + */ +public class ADCircuitBreakerService { + + private final ConcurrentMap breakers = new ConcurrentHashMap<>(); + private final JvmService jvmService; + + private static final Logger logger = LogManager.getLogger(ADCircuitBreakerService.class); + + /** + * Constructor. + * + * @param jvmService jvm info + */ + public ADCircuitBreakerService(JvmService jvmService) { + this.jvmService = jvmService; + } + + public void registerBreaker(String name, CircuitBreaker breaker) { + breakers.putIfAbsent(name, breaker); + } + + public void unregisterBreaker(String name) { + if (name == null) { + return; + } + + breakers.remove(name); + } + + public void clearBreakers() { + breakers.clear(); + } + + public CircuitBreaker getBreaker(String name) { + return breakers.get(name); + } + + /** + * Initialize circuit breaker service. + * + * Register memory breaker by default. + * + * @return ADCircuitBreakerService + */ + public ADCircuitBreakerService init() { + // Register memory circuit breaker + registerBreaker(BreakerName.MEM.getName(), new MemoryCircuitBreaker(this.jvmService)); + logger.info("Registered memory breaker."); + + return this; + } + + public Boolean isOpen() { + if (!ADEnabledSetting.isADBreakerEnabled()) { + return false; + } + + for (CircuitBreaker breaker : breakers.values()) { + if (breaker.isOpen()) { + return true; + } + } + + return false; + } +} diff --git a/src/main/java/org/opensearch/ad/breaker/BreakerName.java-e b/src/main/java/org/opensearch/ad/breaker/BreakerName.java-e new file mode 100644 index 000000000..a6405cf1f --- /dev/null +++ b/src/main/java/org/opensearch/ad/breaker/BreakerName.java-e @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.breaker; + +public enum BreakerName { + + MEM("memory"), + CPU("cpu"); + + private String name; + + BreakerName(String name) { + this.name = name; + } + + public String getName() { + return name; + } +} diff --git a/src/main/java/org/opensearch/ad/breaker/CircuitBreaker.java-e b/src/main/java/org/opensearch/ad/breaker/CircuitBreaker.java-e new file mode 100644 index 000000000..2825d2f98 --- /dev/null +++ b/src/main/java/org/opensearch/ad/breaker/CircuitBreaker.java-e @@ -0,0 +1,22 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.breaker; + +/** + * An interface for circuit breaker. + * + * We use circuit breaker to protect a certain system resource like memory, cpu etc. 
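A minimal caller-side sketch of how this registry is meant to be consulted; the jvmService handle, the lambda breaker, and the load-shedding branch are illustrative assumptions rather than part of this change:

// Sketch only: build the service (init() registers the default memory breaker) and gate
// expensive work on the aggregate breaker state.
ADCircuitBreakerService adBreakerService = new ADCircuitBreakerService(jvmService).init();

// CircuitBreaker exposes a single isOpen() method, so a lambda is enough for illustration.
adBreakerService.registerBreaker(BreakerName.CPU.getName(), () -> false);

if (adBreakerService.isOpen()) {
    // At least one registered breaker tripped (and the AD breaker setting is enabled):
    // skip model training and scoring for this interval instead of risking an OOM.
}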
+ */ +public interface CircuitBreaker { + + boolean isOpen(); +} diff --git a/src/main/java/org/opensearch/ad/breaker/MemoryCircuitBreaker.java-e b/src/main/java/org/opensearch/ad/breaker/MemoryCircuitBreaker.java-e new file mode 100644 index 000000000..c4628c639 --- /dev/null +++ b/src/main/java/org/opensearch/ad/breaker/MemoryCircuitBreaker.java-e @@ -0,0 +1,38 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.breaker; + +import org.opensearch.monitor.jvm.JvmService; + +/** + * A circuit breaker for memory usage. + */ +public class MemoryCircuitBreaker extends ThresholdCircuitBreaker<Short> { + + public static final short DEFAULT_JVM_HEAP_USAGE_THRESHOLD = 85; + private final JvmService jvmService; + + public MemoryCircuitBreaker(JvmService jvmService) { + super(DEFAULT_JVM_HEAP_USAGE_THRESHOLD); + this.jvmService = jvmService; + } + + public MemoryCircuitBreaker(short threshold, JvmService jvmService) { + super(threshold); + this.jvmService = jvmService; + } + + @Override + public boolean isOpen() { + return jvmService.stats().getMem().getHeapUsedPercent() > this.getThreshold(); + } +} diff --git a/src/main/java/org/opensearch/ad/breaker/ThresholdCircuitBreaker.java-e b/src/main/java/org/opensearch/ad/breaker/ThresholdCircuitBreaker.java-e new file mode 100644 index 000000000..30959b0c4 --- /dev/null +++ b/src/main/java/org/opensearch/ad/breaker/ThresholdCircuitBreaker.java-e @@ -0,0 +1,32 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.breaker; + +/** + * An abstract class for all breakers with threshold. + * @param <T> data type of threshold + */ +public abstract class ThresholdCircuitBreaker<T> implements CircuitBreaker { + + private T threshold; + + public ThresholdCircuitBreaker(T threshold) { + this.threshold = threshold; + } + + public T getThreshold() { + return threshold; + } + + @Override + public abstract boolean isOpen(); +} diff --git a/src/main/java/org/opensearch/ad/caching/CacheBuffer.java-e b/src/main/java/org/opensearch/ad/caching/CacheBuffer.java-e new file mode 100644 index 000000000..d9ec0143d --- /dev/null +++ b/src/main/java/org/opensearch/ad/caching/CacheBuffer.java-e @@ -0,0 +1,552 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details.
+ */ + +package org.opensearch.ad.caching; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; + +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.ExpiringState; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.MemoryTracker.Origin; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.model.InitProgressProfile; +import org.opensearch.ad.ratelimit.CheckpointMaintainRequest; +import org.opensearch.ad.ratelimit.CheckpointMaintainWorker; +import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.ratelimit.RequestPriority; +import org.opensearch.ad.util.DateUtils; + +/** + * We use a layered cache to manage active entities’ states. We have a two-level + * cache that stores active entity states in each node. Each detector has its + * dedicated cache that stores ten (dynamically adjustable) entities’ states per + * node. A detector’s hottest entities load their states in the dedicated cache. + * If less than 10 entities use the dedicated cache, the secondary cache can use + * the rest of the free memory available to AD. The secondary cache is a shared + * memory among all detectors for the long tail. The shared cache size is 10% + * heap minus all of the dedicated cache consumed by single-entity and multi-entity + * detectors. The shared cache’s size shrinks as the dedicated cache is filled + * up or more detectors are started. + * + * Implementation-wise, both dedicated cache and shared cache are stored in items + * and minimumCapacity controls the boundary. If items size is equals to or less + * than minimumCapacity, consider items as dedicated cache; otherwise, consider + * top minimumCapacity active entities (last X entities in priorityList) as in dedicated + * cache and all others in shared cache. + */ +public class CacheBuffer implements ExpiringState { + private static final Logger LOG = LogManager.getLogger(CacheBuffer.class); + + // max entities to track per detector + private final int MAX_TRACKING_ENTITIES = 1000000; + + // the reserved cache size. 
So no matter how many entities there are, we will + // keep the size for minimum capacity entities + private int minimumCapacity; + + // key is model id + private final ConcurrentHashMap> items; + // memory consumption per entity + private final long memoryConsumptionPerEntity; + private final MemoryTracker memoryTracker; + private final Duration modelTtl; + private final String detectorId; + private Instant lastUsedTime; + private long reservedBytes; + private final PriorityTracker priorityTracker; + private final Clock clock; + private final CheckpointWriteWorker checkpointWriteQueue; + private final CheckpointMaintainWorker checkpointMaintainQueue; + private int checkpointIntervalHrs; + + public CacheBuffer( + int minimumCapacity, + long intervalSecs, + long memoryConsumptionPerEntity, + MemoryTracker memoryTracker, + Clock clock, + Duration modelTtl, + String detectorId, + CheckpointWriteWorker checkpointWriteQueue, + CheckpointMaintainWorker checkpointMaintainQueue, + int checkpointIntervalHrs + ) { + this.memoryConsumptionPerEntity = memoryConsumptionPerEntity; + setMinimumCapacity(minimumCapacity); + + this.items = new ConcurrentHashMap<>(); + this.memoryTracker = memoryTracker; + + this.modelTtl = modelTtl; + this.detectorId = detectorId; + this.lastUsedTime = clock.instant(); + + this.clock = clock; + this.priorityTracker = new PriorityTracker(clock, intervalSecs, clock.instant().getEpochSecond(), MAX_TRACKING_ENTITIES); + this.checkpointWriteQueue = checkpointWriteQueue; + this.checkpointMaintainQueue = checkpointMaintainQueue; + setCheckpointIntervalHrs(checkpointIntervalHrs); + } + + /** + * Update step at period t_k: + * new priority = old priority + log(1+e^{\log(g(t_k-L))-old priority}) where g(n) = e^{0.125n}, + * and n is the period. + * @param entityModelId model Id + */ + private void update(String entityModelId) { + priorityTracker.updatePriority(entityModelId); + + Instant now = clock.instant(); + items.get(entityModelId).setLastUsedTime(now); + lastUsedTime = now; + } + + /** + * Insert the model state associated with a model Id to the cache + * @param entityModelId the model Id + * @param value the ModelState + */ + public void put(String entityModelId, ModelState value) { + // race conditions can happen between the put and one of the following operations: + // remove: not a problem as it is unlikely we are removing and putting the same thing + // maintenance: not a problem as we are unlikely to maintain an entry that's not + // already in the cache + // clear: not a problem as we are releasing memory in MemoryTracker. + // The newly added one loses references and soon GC will collect it. + // We have memory tracking correction to fix incorrect memory usage record. + // put from other threads: not a problem as the entry is associated with + // entityModelId and our put is idempotent + put(entityModelId, value, value.getPriority()); + } + + /** + * Insert the model state associated with a model Id to the cache. Update priority. + * @param entityModelId the model Id + * @param value the ModelState + * @param priority the priority + */ + private void put(String entityModelId, ModelState value, float priority) { + ModelState contentNode = items.get(entityModelId); + if (contentNode == null) { + priorityTracker.addPriority(entityModelId, priority); + items.put(entityModelId, value); + Instant now = clock.instant(); + value.setLastUsedTime(now); + lastUsedTime = now; + // shared cache empty means we are consuming reserved cache. 
+ // Since we have already considered them while allocating CacheBuffer, + // skip bookkeeping. + if (!sharedCacheEmpty()) { + memoryTracker.consumeMemory(memoryConsumptionPerEntity, false, Origin.HC_DETECTOR); + } + } else { + update(entityModelId); + items.put(entityModelId, value); + } + } + + /** + * Retrieve the ModelState associated with the model Id or null if the CacheBuffer + * contains no mapping for the model Id + * @param key the model Id + * @return the Model state to which the specified model Id is mapped, or null + * if this CacheBuffer contains no mapping for the model Id + */ + public ModelState get(String key) { + // We can get an item that is to be removed soon due to race condition. + // This is acceptable as it won't cause any corruption and exception. + // And this item is used for scoring one last time. + ModelState node = items.get(key); + if (node == null) { + return null; + } + update(key); + return node; + } + + /** + * Retrieve the ModelState associated with the model Id or null if the CacheBuffer + * contains no mapping for the model Id. Compared to get method, the method won't + * increment entity priority. Used in cache buffer maintenance. + * + * @param key the model Id + * @return the Model state to which the specified model Id is mapped, or null + * if this CacheBuffer contains no mapping for the model Id + */ + public ModelState getWithoutUpdatePriority(String key) { + // We can get an item that is to be removed soon due to race condition. + // This is acceptable as it won't cause any corruption and exception. + // And this item is used for scoring one last time. + ModelState node = items.get(key); + if (node == null) { + return null; + } + return node; + } + + /** + * + * @return whether there is one item that can be removed from shared cache + */ + public boolean canRemove() { + return !items.isEmpty() && items.size() > minimumCapacity; + } + + /** + * remove the smallest priority item. + * @return the associated ModelState associated with the key, or null if there + * is no associated ModelState for the key + */ + public ModelState remove() { + // race conditions can happen between the put and one of the following operations: + // remove from other threads: not a problem. If they remove the same item, + // our method is idempotent. If they remove two different items, + // they don't impact each other. + // maintenance: not a problem as all of the data structures are concurrent. + // Two threads removing the same entry is not a problem. + // clear: not a problem as we are releasing memory in MemoryTracker. + // The removed one loses references and soon GC will collect it. + // We have memory tracking correction to fix incorrect memory usage record. + // put: not a problem as it is unlikely we are removing and putting the same thing + Optional key = priorityTracker.getMinimumPriorityEntityId(); + if (key.isPresent()) { + return remove(key.get()); + } + return null; + } + + /** + * Remove everything associated with the key and make a checkpoint. + * + * @param keyToRemove The key to remove + * @return the associated ModelState associated with the key, or null if there + * is no associated ModelState for the key + */ + public ModelState remove(String keyToRemove) { + return remove(keyToRemove, true); + } + + /** + * Remove everything associated with the key and make a checkpoint if input specified so. 
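A small worked example of the dedicated/shared accounting described in the class comment above (the numbers are illustrative, not taken from this change):

// Suppose minimumCapacity = 10 and memoryConsumptionPerEntity = 100 KB.
// reservedBytes = 10 * 100 KB = 1 MB and is held no matter how many entities are active.
// With 13 active entities: items.size() > minimumCapacity, so the top 10 entities count
//   against the dedicated cache and the remaining 3 live in the shared cache;
//   sharedCacheEmpty() is false and getBytesInSharedCache() = 3 * 100 KB = 300 KB.
// With 8 active entities: everything fits in the dedicated cache, sharedCacheEmpty() is
//   true, and put()/remove() skip the MemoryTracker bookkeeping because the reserved
//   bytes were already accounted for when the CacheBuffer was allocated.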
+ * + * @param keyToRemove The key to remove + * @param saveCheckpoint Whether saving checkpoint or not + * @return the associated ModelState associated with the key, or null if there + * is no associated ModelState for the key + */ + public ModelState remove(String keyToRemove, boolean saveCheckpoint) { + priorityTracker.removePriority(keyToRemove); + + // if shared cache is empty, we are using reserved memory + boolean reserved = sharedCacheEmpty(); + + ModelState valueRemoved = items.remove(keyToRemove); + + if (valueRemoved != null) { + if (!reserved) { + // release in shared memory + memoryTracker.releaseMemory(memoryConsumptionPerEntity, false, Origin.HC_DETECTOR); + } + + EntityModel modelRemoved = valueRemoved.getModel(); + if (modelRemoved != null) { + if (saveCheckpoint) { + // null model has only samples. For null model we save a checkpoint + // regardless of last checkpoint time. whether If we don't save, + // we throw the new samples and might never be able to initialize the model + boolean isNullModel = !modelRemoved.getTrcf().isPresent(); + checkpointWriteQueue.write(valueRemoved, isNullModel, RequestPriority.MEDIUM); + } + + modelRemoved.clear(); + } + } + + return valueRemoved; + } + + /** + * @return whether dedicated cache is available or not + */ + public boolean dedicatedCacheAvailable() { + return items.size() < minimumCapacity; + } + + /** + * @return whether shared cache is empty or not + */ + public boolean sharedCacheEmpty() { + return items.size() <= minimumCapacity; + } + + /** + * + * @return the estimated number of bytes per entity state + */ + public long getMemoryConsumptionPerEntity() { + return memoryConsumptionPerEntity; + } + + /** + * + * If the cache is not full, check if some other items can replace internal entities + * within the same detector. + * + * @param priority another entity's priority + * @return whether one entity can be replaced by another entity with a certain priority + */ + public boolean canReplaceWithinDetector(float priority) { + if (items.isEmpty()) { + return false; + } + Optional> minPriorityItem = priorityTracker.getMinimumPriority(); + return minPriorityItem.isPresent() && priority > minPriorityItem.get().getValue(); + } + + /** + * Replace the smallest priority entity with the input entity + * @param entityModelId the Model Id + * @param value the model State + * @return the associated ModelState associated with the key, or null if there + * is no associated ModelState for the key + */ + public ModelState replace(String entityModelId, ModelState value) { + ModelState replaced = remove(); + put(entityModelId, value); + return replaced; + } + + /** + * Remove expired state and save checkpoints of existing states + * @return removed states + */ + public List> maintenance() { + List modelsToSave = new ArrayList<>(); + List> removedStates = new ArrayList<>(); + Instant now = clock.instant(); + int currentHour = DateUtils.getUTCHourOfDay(now); + int currentSlot = currentHour % checkpointIntervalHrs; + items.entrySet().stream().forEach(entry -> { + String entityModelId = entry.getKey(); + try { + ModelState modelState = entry.getValue(); + + if (modelState.getLastUsedTime().plus(modelTtl).isBefore(now)) { + // race conditions can happen between the put and one of the following operations: + // remove: not a problem as all of the data structures are concurrent. + // Two threads removing the same entry is not a problem. + // clear: not a problem as we are releasing memory in MemoryTracker. 
+ // The removed one loses references and soon GC will collect it. + // We have memory tracking correction to fix incorrect memory usage record. + // put: not a problem as we are unlikely to maintain an entry that's not + // already in the cache + // remove method saves checkpoint as well + removedStates.add(remove(entityModelId)); + } else if (Math.abs(entityModelId.hashCode()) % checkpointIntervalHrs == currentSlot) { + // checkpoint is relatively big compared to other queued requests + // Evens out the resource usage more fairly across a large maintenance window + // by adding saving requests to CheckpointMaintainWorker. + // + // Background: + // We will save a checkpoint when + // + // (a)removing the model from cache. + // (b) cold start + // (c) no complete model only a few samples. If we don't save new samples, + // we will never be able to have enough samples for a trained mode. + // (d) periodically save in case of exceptions. + // + // This branch is doing d). Previously, I will do it every hour for all + // in-cache models. Consider we are moving to 1M entities, this will bring + // the cluster in a heavy payload every hour. That's why I am doing it randomly + // (expected 6 hours for each checkpoint statistically). + // + // I am doing it random since maintaining a state of which one has been saved + // and which one hasn't are not cheap. Also, the models in the cache can be + // dynamically changing. Will have to maintain the state in the removing logic. + // Random is a lazy way to deal with this as it is stateless and statistically sound. + // + // If a checkpoint does not fall into the 6-hour bucket in a particular scenario, the model + // is stale (i.e., we don't recover from the freshest model in disaster.). + // + // All in all, randomness is mostly due to performance and easy maintenance. + modelsToSave + .add( + new CheckpointMaintainRequest( + // the request expires when the next maintainance starts + System.currentTimeMillis() + modelTtl.toMillis(), + detectorId, + RequestPriority.LOW, + entityModelId + ) + ); + } + + } catch (Exception e) { + LOG.warn("Failed to finish maintenance for model id " + entityModelId, e); + } + }); + + checkpointMaintainQueue.putAll(modelsToSave); + return removedStates; + } + + /** + * + * @return the number of active entities + */ + public int getActiveEntities() { + return items.size(); + } + + /** + * + * @param entityModelId Model Id + * @return Whether the model is active or not + */ + public boolean isActive(String entityModelId) { + return items.containsKey(entityModelId); + } + + /** + * + * @param entityModelId Model Id + * @return Last used time of the model + */ + public long getLastUsedTime(String entityModelId) { + ModelState state = items.get(entityModelId); + if (state != null) { + return state.getLastUsedTime().toEpochMilli(); + } + return -1; + } + + /** + * + * @param entityModelId entity Id + * @return Get the model of an entity + */ + public Optional getModel(String entityModelId) { + return Optional.of(items).map(map -> map.get(entityModelId)).map(state -> state.getModel()); + } + + /** + * Clear associated memory. Used when we are removing an detector. + */ + public void clear() { + // race conditions can happen between the put and remove/maintenance/put: + // not a problem as we are releasing memory in MemoryTracker. + // The newly added one loses references and soon GC will collect it. + // We have memory tracking correction to fix incorrect memory usage record. 
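To make the randomized checkpointing above concrete, a short worked example of the slot arithmetic in maintenance(), assuming the pass runs at least hourly (the values are illustrative):

// Suppose checkpointIntervalHrs = 6 and maintenance runs at 14:00 UTC:
//   currentSlot = 14 % 6 = 2.
// A model whose id satisfies Math.abs(entityModelId.hashCode()) % 6 == 2 gets a LOW-priority
// CheckpointMaintainRequest this hour; models hashing to the other five slots are skipped.
// Over a rolling 6-hour window each in-cache model's slot comes up once, so periodic
// checkpoints are spread across the window instead of being written in one burst, at the
// cost that a recovered model can be up to roughly 6 hours stale after a crash.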
+ memoryTracker.releaseMemory(getReservedBytes(), true, Origin.HC_DETECTOR); + if (!sharedCacheEmpty()) { + memoryTracker.releaseMemory(getBytesInSharedCache(), false, Origin.HC_DETECTOR); + } + items.clear(); + priorityTracker.clearPriority(); + } + + /** + * + * @return reserved bytes by the CacheBuffer + */ + public long getReservedBytes() { + return reservedBytes; + } + + /** + * + * @return bytes consumed in the shared cache by the CacheBuffer + */ + public long getBytesInSharedCache() { + int sharedCacheEntries = items.size() - minimumCapacity; + if (sharedCacheEntries > 0) { + return memoryConsumptionPerEntity * sharedCacheEntries; + } + return 0; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + if (obj instanceof InitProgressProfile) { + CacheBuffer other = (CacheBuffer) obj; + + EqualsBuilder equalsBuilder = new EqualsBuilder(); + equalsBuilder.append(detectorId, other.detectorId); + + return equalsBuilder.isEquals(); + } + return false; + } + + @Override + public int hashCode() { + return new HashCodeBuilder().append(detectorId).toHashCode(); + } + + @Override + public boolean expired(Duration stateTtl) { + return expired(lastUsedTime, stateTtl, clock.instant()); + } + + public String getId() { + return detectorId; + } + + public List> getAllModels() { + return items.values().stream().collect(Collectors.toList()); + } + + public PriorityTracker getPriorityTracker() { + return priorityTracker; + } + + public void setMinimumCapacity(int minimumCapacity) { + if (minimumCapacity < 0) { + throw new IllegalArgumentException("minimum capacity should be larger than or equal 0"); + } + this.minimumCapacity = minimumCapacity; + this.reservedBytes = memoryConsumptionPerEntity * minimumCapacity; + } + + public void setCheckpointIntervalHrs(int checkpointIntervalHrs) { + this.checkpointIntervalHrs = checkpointIntervalHrs; + // 0 can cause java.lang.ArithmeticException: / by zero + // negative value is meaningless + if (checkpointIntervalHrs <= 0) { + this.checkpointIntervalHrs = 1; + } + } + + public int getCheckpointIntervalHrs() { + return checkpointIntervalHrs; + } +} diff --git a/src/main/java/org/opensearch/ad/caching/CacheProvider.java-e b/src/main/java/org/opensearch/ad/caching/CacheProvider.java-e new file mode 100644 index 000000000..ab8fd191c --- /dev/null +++ b/src/main/java/org/opensearch/ad/caching/CacheProvider.java-e @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.caching; + +import org.opensearch.common.inject.Provider; + +/** + * A wrapper to call concrete implementation of caching. Used in transport + * action. Don't use interface because transport action handler constructor + * requires a concrete class as input. 
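A sketch of the wiring this indirection enables; the construction order and variable names below are assumptions for illustration, not taken from this change:

// Transport actions are constructed before the concrete cache exists, so they receive the
// provider and dereference it lazily.
CacheProvider cacheProvider = new CacheProvider();
// ... later, once the PriorityCache has been built with its checkpoint queues and trackers:
cacheProvider.set(priorityCache);
// ... inside a transport action handler, resolve the cache only at request time:
EntityCache cache = cacheProvider.get();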
+ * + */ +public class CacheProvider implements Provider { + private EntityCache cache; + + public CacheProvider() { + + } + + @Override + public EntityCache get() { + return cache; + } + + public void set(EntityCache cache) { + this.cache = cache; + } +} diff --git a/src/main/java/org/opensearch/ad/caching/DoorKeeper.java-e b/src/main/java/org/opensearch/ad/caching/DoorKeeper.java-e new file mode 100644 index 000000000..96a18d8f6 --- /dev/null +++ b/src/main/java/org/opensearch/ad/caching/DoorKeeper.java-e @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.caching; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.ExpiringState; +import org.opensearch.ad.MaintenanceState; + +import com.google.common.base.Charsets; +import com.google.common.hash.BloomFilter; +import com.google.common.hash.Funnels; + +/** + * A bloom filter with regular reset. + * + * Reference: https://arxiv.org/abs/1512.00727 + * + */ +public class DoorKeeper implements MaintenanceState, ExpiringState { + private final Logger LOG = LogManager.getLogger(DoorKeeper.class); + // stores entity's model id + private BloomFilter bloomFilter; + // the number of expected insertions to the constructed BloomFilter; must be positive + private final long expectedInsertions; + // the desired false positive probability (must be positive and less than 1.0) + private final double fpp; + private Instant lastMaintenanceTime; + private final Duration resetInterval; + private final Clock clock; + private Instant lastAccessTime; + + public DoorKeeper(long expectedInsertions, double fpp, Duration resetInterval, Clock clock) { + this.expectedInsertions = expectedInsertions; + this.fpp = fpp; + this.resetInterval = resetInterval; + this.clock = clock; + this.lastAccessTime = clock.instant(); + maintenance(); + } + + public boolean mightContain(String modelId) { + this.lastAccessTime = clock.instant(); + return bloomFilter.mightContain(modelId); + } + + public boolean put(String modelId) { + this.lastAccessTime = clock.instant(); + return bloomFilter.put(modelId); + } + + /** + * We reset the bloom filter when bloom filter is null or it is state ttl is reached + */ + @Override + public void maintenance() { + if (bloomFilter == null || lastMaintenanceTime.plus(resetInterval).isBefore(clock.instant())) { + LOG.debug("maintaining for doorkeeper"); + bloomFilter = BloomFilter.create(Funnels.stringFunnel(Charsets.US_ASCII), expectedInsertions, fpp); + lastMaintenanceTime = clock.instant(); + } + } + + @Override + public boolean expired(Duration stateTtl) { + // ignore stateTtl since we have customized resetInterval + return expired(lastAccessTime, resetInterval, clock.instant()); + } +} diff --git a/src/main/java/org/opensearch/ad/caching/EntityCache.java-e b/src/main/java/org/opensearch/ad/caching/EntityCache.java-e new file mode 100644 index 000000000..0a6a303d6 --- /dev/null +++ b/src/main/java/org/opensearch/ad/caching/EntityCache.java-e @@ -0,0 +1,157 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 
license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.caching; + +import java.util.Collection; +import java.util.List; +import java.util.Optional; + +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.ad.CleanState; +import org.opensearch.ad.DetectorModelSize; +import org.opensearch.ad.MaintenanceState; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.ModelProfile; +import org.opensearch.timeseries.model.Entity; + +public interface EntityCache extends MaintenanceState, CleanState, DetectorModelSize { + /** + * Get the ModelState associated with the entity. May or may not load the + * ModelState depending on the underlying cache's eviction policy. + * + * @param modelId Model Id + * @param detector Detector config object + * @return the ModelState associated with the model or null if no cached item + * for the entity + */ + ModelState get(String modelId, AnomalyDetector detector); + + /** + * Get the number of active entities of a detector + * @param detector Detector Id + * @return The number of active entities + */ + int getActiveEntities(String detector); + + /** + * + * @return total active entities in the cache + */ + int getTotalActiveEntities(); + + /** + * Whether an entity is active or not + * @param detectorId The Id of the detector that an entity belongs to + * @param entityModelId Entity model Id + * @return Whether an entity is active or not + */ + boolean isActive(String detectorId, String entityModelId); + + /** + * Get total updates of detector's most active entity's RCF model. + * + * @param detectorId detector id + * @return RCF model total updates of most active entity. + */ + long getTotalUpdates(String detectorId); + + /** + * Get RCF model total updates of specific entity + * + * @param detectorId detector id + * @param entityModelId entity model id + * @return RCF model total updates of specific entity. + */ + long getTotalUpdates(String detectorId, String entityModelId); + + /** + * Gets modelStates of all model hosted on a node + * + * @return list of modelStates + */ + List> getAllModels(); + + /** + * Return when the last active time of an entity's state. + * + * If the entity's state is active in the cache, the value indicates when the cache + * is lastly accessed (get/put). If the entity's state is inactive in the cache, + * the value indicates when the cache state is created or when the entity is evicted + * from active entity cache. + * + * @param detectorId The Id of the detector that an entity belongs to + * @param entityModelId Entity's Model Id + * @return if the entity is in the cache, return the timestamp in epoch + * milliseconds when the entity's state is lastly used. Otherwise, return -1. 
+ */ + long getLastActiveMs(String detectorId, String entityModelId); + + /** + * Release memory when memory circuit breaker is open + */ + void releaseMemoryForOpenCircuitBreaker(); + + /** + * Select candidate entities for which we can load models + * @param cacheMissEntities Cache miss entities + * @param detectorId Detector Id + * @param detector Detector object + * @return A list of entities that are admitted into the cache as a result of the + * update and the left-over entities + */ + Pair, List> selectUpdateCandidate( + Collection cacheMissEntities, + String detectorId, + AnomalyDetector detector + ); + + /** + * + * @param detector Detector config + * @param toUpdate Model state candidate + * @return if we can host the given model state + */ + boolean hostIfPossible(AnomalyDetector detector, ModelState toUpdate); + + /** + * + * @param detectorId Detector Id + * @return a detector's model information + */ + List getAllModelProfile(String detectorId); + + /** + * Gets an entity's model sizes + * + * @param detectorId Detector Id + * @param entityModelId Entity's model Id + * @return the entity's memory size + */ + Optional getModelProfile(String detectorId, String entityModelId); + + /** + * Get a model state without incurring priority update. Used in maintenance. + * @param detectorId Detector Id + * @param modelId Model Id + * @return Model state + */ + Optional> getForMaintainance(String detectorId, String modelId); + + /** + * Remove entity model from active entity buffer and delete checkpoint. Used to clean corrupted model. + * @param detectorId Detector Id + * @param entityModelId Model Id + */ + void removeEntityModel(String detectorId, String entityModelId); +} diff --git a/src/main/java/org/opensearch/ad/caching/PriorityCache.java b/src/main/java/org/opensearch/ad/caching/PriorityCache.java index 014c97c00..be8c05397 100644 --- a/src/main/java/org/opensearch/ad/caching/PriorityCache.java +++ b/src/main/java/org/opensearch/ad/caching/PriorityCache.java @@ -39,7 +39,6 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.action.ActionListener; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.MemoryTracker; import org.opensearch.ad.MemoryTracker.Origin; import org.opensearch.ad.ml.CheckpointDao; @@ -58,6 +57,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.Strings; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; @@ -548,7 +548,7 @@ private Triple canReplaceInSharedCache(CacheBuffer o private void tryClearUpMemory() { try { if (maintenanceLock.tryLock()) { - threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME).execute(() -> clearMemory()); + threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME).execute(() -> clearMemory()); } else { threadPool.schedule(() -> { try { @@ -556,7 +556,7 @@ private void tryClearUpMemory() { } catch (Exception e) { LOG.error("Fail to clear up memory taken by CacheBuffer. 
Will retry during maintenance."); } - }, new TimeValue(random.nextInt(90), TimeUnit.SECONDS), AnomalyDetectorPlugin.AD_THREAD_POOL_NAME); + }, new TimeValue(random.nextInt(90), TimeUnit.SECONDS), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); } } finally { if (maintenanceLock.isHeldByCurrentThread()) { diff --git a/src/main/java/org/opensearch/ad/caching/PriorityCache.java-e b/src/main/java/org/opensearch/ad/caching/PriorityCache.java-e new file mode 100644 index 000000000..be8c05397 --- /dev/null +++ b/src/main/java/org/opensearch/ad/caching/PriorityCache.java-e @@ -0,0 +1,971 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.caching; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.DEDICATED_CACHE_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.PriorityQueue; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.Triple; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.MemoryTracker.Origin; +import org.opensearch.ad.ml.CheckpointDao; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelManager.ModelType; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.ModelProfile; +import org.opensearch.ad.ratelimit.CheckpointMaintainWorker; +import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.util.DateUtils; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.Strings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; + +public class PriorityCache implements EntityCache { + private final Logger LOG = LogManager.getLogger(PriorityCache.class); + + // detector id -> CacheBuffer, weight based + private final Map activeEnities; + private final 
CheckpointDao checkpointDao; + private volatile int dedicatedCacheSize; + // LRU Cache, key is model id + private Cache> inActiveEntities; + private final MemoryTracker memoryTracker; + private final ReentrantLock maintenanceLock; + private final int numberOfTrees; + private final Clock clock; + private final Duration modelTtl; + // A bloom filter placed in front of inactive entity cache to + // filter out unpopular items that are not likely to appear more + // than once. Key is detector id + private Map doorKeepers; + private ThreadPool threadPool; + private Random random; + private CheckpointWriteWorker checkpointWriteQueue; + // iterating through all of inactive entities is heavy. We don't want to do + // it again and again for no obvious benefits. + private Instant lastInActiveEntityMaintenance; + protected int maintenanceFreqConstant; + private CheckpointMaintainWorker checkpointMaintainQueue; + private int checkpointIntervalHrs; + + public PriorityCache( + CheckpointDao checkpointDao, + int dedicatedCacheSize, + Setting checkpointTtl, + int maxInactiveStates, + MemoryTracker memoryTracker, + int numberOfTrees, + Clock clock, + ClusterService clusterService, + Duration modelTtl, + ThreadPool threadPool, + CheckpointWriteWorker checkpointWriteQueue, + int maintenanceFreqConstant, + CheckpointMaintainWorker checkpointMaintainQueue, + Settings settings, + Setting checkpointSavingFreq + ) { + this.checkpointDao = checkpointDao; + + this.activeEnities = new ConcurrentHashMap<>(); + this.dedicatedCacheSize = dedicatedCacheSize; + clusterService.getClusterSettings().addSettingsUpdateConsumer(DEDICATED_CACHE_SIZE, (it) -> { + this.dedicatedCacheSize = it; + this.setDedicatedCacheSizeListener(); + this.tryClearUpMemory(); + }, this::validateDedicatedCacheSize); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MODEL_MAX_SIZE_PERCENTAGE, it -> this.tryClearUpMemory()); + + this.memoryTracker = memoryTracker; + this.maintenanceLock = new ReentrantLock(); + this.numberOfTrees = numberOfTrees; + this.clock = clock; + this.modelTtl = modelTtl; + this.doorKeepers = new ConcurrentHashMap<>(); + + Duration inactiveEntityTtl = DateUtils.toDuration(checkpointTtl.get(settings)); + + this.inActiveEntities = createInactiveCache(inactiveEntityTtl, maxInactiveStates); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer( + checkpointTtl, + it -> { this.inActiveEntities = createInactiveCache(DateUtils.toDuration(it), maxInactiveStates); } + ); + + this.threadPool = threadPool; + this.random = new Random(42); + this.checkpointWriteQueue = checkpointWriteQueue; + this.lastInActiveEntityMaintenance = Instant.MIN; + this.maintenanceFreqConstant = maintenanceFreqConstant; + this.checkpointMaintainQueue = checkpointMaintainQueue; + + this.checkpointIntervalHrs = DateUtils.toDuration(checkpointSavingFreq.get(settings)).toHoursPart(); + clusterService.getClusterSettings().addSettingsUpdateConsumer(checkpointSavingFreq, it -> { + this.checkpointIntervalHrs = DateUtils.toDuration(it).toHoursPart(); + this.setCheckpointFreqListener(); + }); + } + + @Override + public ModelState get(String modelId, AnomalyDetector detector) { + String detectorId = detector.getId(); + CacheBuffer buffer = computeBufferIfAbsent(detector, detectorId); + ModelState modelState = buffer.get(modelId); + + // during maintenance period, stop putting new entries + if (!maintenanceLock.isLocked() && modelState == null) { + if (ADEnabledSetting.isDoorKeeperInCacheEnabled()) { + DoorKeeper doorKeeper = doorKeepers + 
.computeIfAbsent( + detectorId, + id -> { + // reset every 60 intervals + return new DoorKeeper( + TimeSeriesSettings.DOOR_KEEPER_FOR_CACHE_MAX_INSERTION, + TimeSeriesSettings.DOOR_KEEPER_FALSE_POSITIVE_RATE, + detector.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ), + clock + ); + } + ); + + // first hit, ignore + // since door keeper may get reset during maintenance, it is possible + // the entity is still active even though door keeper has no record of + // this model Id. We have to call isActive method to make sure. Otherwise, + // the entity might miss an anomaly result every 60 intervals due to door keeper + // reset. + if (!doorKeeper.mightContain(modelId) && !isActive(detectorId, modelId)) { + doorKeeper.put(modelId); + return null; + } + } + + try { + ModelState state = inActiveEntities.get(modelId, new Callable>() { + @Override + public ModelState call() { + return new ModelState<>(null, modelId, detectorId, ModelType.ENTITY.getName(), clock, 0); + } + }); + + // make sure no model has been stored due to previous race conditions + state.setModel(null); + + // compute updated priority + // We don’t want to admit the latest entity for correctness by throwing out a + // hot entity. We have a priority (time-decayed count) sensitive to + // the number of hits, length of time, and sampling interval. Examples: + // 1) an entity from a 5-minute interval detector that is hit 5 times in the + // past 25 minutes should have an equal chance of using the cache along with + // an entity from a 1-minute interval detector that is hit 5 times in the past + // 5 minutes. + // 2) It might be the case that the frequency of entities changes dynamically + // during run-time. For example, entity A showed up for the first 500 times, + // but entity B showed up for the next 500 times. Our priority should give + // entity B higher priority than entity A as time goes by. + // 3) Entity A will have a higher priority than entity B if A runs + // for a longer time given other things are equal. + // + // We ensure fairness by using periods instead of absolute duration. Entity A + // accessed once three intervals ago should have the same priority with entity B + // accessed once three periods ago, though they belong to detectors of different + // intervals. 
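+ // Editor's illustrative note (hypothetical numbers; not part of the original change):
+ // for a 5-minute detector, 5 hits over the past 25 minutes span 5 periods, and for a
+ // 1-minute detector, 5 hits over the past 5 minutes also span 5 periods. Measured in
+ // periods, both entities accumulate comparable time-decayed counts, so a shorter
+ // detector interval does not by itself win more cache slots.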
+ + // update state using new priority or create a new one + state.setPriority(buffer.getPriorityTracker().getUpdatedPriority(state.getPriority())); + + // adjust shared memory in case we have used dedicated cache memory for other detectors + if (random.nextInt(maintenanceFreqConstant) == 1) { + tryClearUpMemory(); + } + } catch (Exception e) { + LOG.error(new ParameterizedMessage("Fail to update priority of [{}]", modelId), e); + } + + } + + return modelState; + } + + private Optional> getStateFromInactiveEntiiyCache(String modelId) { + if (modelId == null) { + return Optional.empty(); + } + + // null if not even recorded in inActiveEntities yet because of doorKeeper + return Optional.ofNullable(inActiveEntities.getIfPresent(modelId)); + } + + @Override + public boolean hostIfPossible(AnomalyDetector detector, ModelState toUpdate) { + if (toUpdate == null) { + return false; + } + String modelId = toUpdate.getModelId(); + String detectorId = toUpdate.getId(); + + if (Strings.isEmpty(modelId) || Strings.isEmpty(detectorId)) { + return false; + } + + CacheBuffer buffer = computeBufferIfAbsent(detector, detectorId); + + Optional> state = getStateFromInactiveEntiiyCache(modelId); + if (false == state.isPresent()) { + return false; + } + + ModelState modelState = state.get(); + + float priority = modelState.getPriority(); + + toUpdate.setLastUsedTime(clock.instant()); + toUpdate.setPriority(priority); + + // current buffer's dedicated cache has free slots or can allocate in shared cache + if (buffer.dedicatedCacheAvailable() || memoryTracker.canAllocate(buffer.getMemoryConsumptionPerEntity())) { + // buffer.put will call MemoryTracker.consumeMemory + buffer.put(modelId, toUpdate); + return true; + } + + if (memoryTracker.canAllocate(buffer.getMemoryConsumptionPerEntity())) { + // buffer.put will call MemoryTracker.consumeMemory + buffer.put(modelId, toUpdate); + return true; + } + + // can replace an entity in the same CacheBuffer living in reserved or shared cache + if (buffer.canReplaceWithinDetector(priority)) { + ModelState removed = buffer.replace(modelId, toUpdate); + // null in the case of some other threads have emptied the queue at + // the same time so there is nothing to replace + if (removed != null) { + addIntoInactiveCache(removed); + return true; + } + } + + // If two threads try to remove the same entity and add their own state, the 2nd remove + // returns null and only the first one succeeds. 
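+ // Descriptive note added by the editor (not in the original change): reaching this point means the
+ // dedicated cache is full, the shared cache cannot take another model, and replacement within this
+ // detector's own buffer did not succeed, so we look for the lowest scaled-priority entity across
+ // other detectors' buffers and evict it if the candidate ranks higher.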
+ float scaledPriority = buffer.getPriorityTracker().getScaledPriority(priority); + Triple bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); + CacheBuffer bufferToRemove = bufferToRemoveEntity.getLeft(); + String entityModelId = bufferToRemoveEntity.getMiddle(); + ModelState removed = null; + if (bufferToRemove != null && ((removed = bufferToRemove.remove(entityModelId)) != null)) { + buffer.put(modelId, toUpdate); + addIntoInactiveCache(removed); + return true; + } + + return false; + } + + private void addIntoInactiveCache(ModelState removed) { + if (removed == null) { + return; + } + // set last used time for the profile API so that we know when an entity is evicted + removed.setLastUsedTime(clock.instant()); + removed.setModel(null); + inActiveEntities.put(removed.getModelId(), removed); + } + + private void addEntity(List destination, Entity entity, String detectorId) { + // It's possible our doorkeeper prevented the entity from entering the inactive entity cache + if (entity != null) { + Optional modelId = entity.getModelId(detectorId); + if (modelId.isPresent() && inActiveEntities.getIfPresent(modelId.get()) != null) { + destination.add(entity); + } + } + } + + @Override + public Pair, List> selectUpdateCandidate( + Collection cacheMissEntities, + String detectorId, + AnomalyDetector detector + ) { + List hotEntities = new ArrayList<>(); + List coldEntities = new ArrayList<>(); + + CacheBuffer buffer = activeEnities.get(detectorId); + if (buffer == null) { + // don't want to create side effects by creating a CacheBuffer + // In the current implementation, this branch is impossible as we call + // the PriorityCache.get method before invoking this method. The + // PriorityCache.get method creates a CacheBuffer if not present. + // Since this method is public, we still handle this case to guard against misuse. + return Pair.of(hotEntities, coldEntities); + } + + Iterator cacheMissEntitiesIter = cacheMissEntities.iterator(); + // current buffer's dedicated cache has free slots + while (cacheMissEntitiesIter.hasNext() && buffer.dedicatedCacheAvailable()) { + addEntity(hotEntities, cacheMissEntitiesIter.next(), detectorId); + } + + while (cacheMissEntitiesIter.hasNext() && memoryTracker.canAllocate(buffer.getMemoryConsumptionPerEntity())) { + // can allocate in shared cache + // race conditions can happen when multiple threads evaluate this condition. + // This is a problem when our AD memory usage is close to full and we put in + // more than we planned. One model in HCAD is small, so + // it is fine if we exceed a little. We have regular maintenance to remove + // extra memory usage. + addEntity(hotEntities, cacheMissEntitiesIter.next(), detectorId); + } + + // check if we can replace anything in the dedicated or shared cache + // have a copy since we need to do the iteration twice: once for the + // dedicated cache and once for the shared cache + List otherBufferReplaceCandidates = new ArrayList<>(); + + while (cacheMissEntitiesIter.hasNext()) { + // can replace an entity in the same CacheBuffer living in reserved + // or shared cache + // thread safe as each detector has one thread at one time and only that + // thread can access its buffer. 
+ Entity entity = cacheMissEntitiesIter.next(); + Optional modelId = entity.getModelId(detectorId); + + if (false == modelId.isPresent()) { + continue; + } + + Optional> state = getStateFromInactiveEntiiyCache(modelId.get()); + if (false == state.isPresent()) { + // not even recorded in inActiveEntities yet because of doorKeeper + continue; + } + + ModelState modelState = state.get(); + float priority = modelState.getPriority(); + + if (buffer.canReplaceWithinDetector(priority)) { + addEntity(hotEntities, entity, detectorId); + } else { + // re-evaluate replacement condition in other buffers + otherBufferReplaceCandidates.add(entity); + } + } + + // record current minimum priority among all detectors to save redundant + // scanning of all CacheBuffers + CacheBuffer bufferToRemove = null; + float minPriority = Float.MIN_VALUE; + + // check if we can replace in other CacheBuffer + cacheMissEntitiesIter = otherBufferReplaceCandidates.iterator(); + + while (cacheMissEntitiesIter.hasNext()) { + // If two threads try to remove the same entity and add their own state, the 2nd remove + // returns null and only the first one succeeds. + Entity entity = cacheMissEntitiesIter.next(); + Optional modelId = entity.getModelId(detectorId); + + if (false == modelId.isPresent()) { + continue; + } + + Optional> inactiveState = getStateFromInactiveEntiiyCache(modelId.get()); + if (false == inactiveState.isPresent()) { + // empty state should not stand a chance to replace others + continue; + } + + ModelState state = inactiveState.get(); + + float priority = state.getPriority(); + float scaledPriority = buffer.getPriorityTracker().getScaledPriority(priority); + + if (scaledPriority <= minPriority) { + // not even larger than the minPriority, we can put this to coldEntities + addEntity(coldEntities, entity, detectorId); + continue; + } + + // Float.MIN_VALUE means we need to re-iterate through all CacheBuffers + if (minPriority == Float.MIN_VALUE) { + Triple bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); + bufferToRemove = bufferToRemoveEntity.getLeft(); + minPriority = bufferToRemoveEntity.getRight(); + } + + if (bufferToRemove != null) { + addEntity(hotEntities, entity, detectorId); + // reset minPriority after the replacement so that we need to iterate all CacheBuffer + // again + minPriority = Float.MIN_VALUE; + } else { + // after trying everything, we can now safely put this to cold entities list + addEntity(coldEntities, entity, detectorId); + } + } + + return Pair.of(hotEntities, coldEntities); + } + + private CacheBuffer computeBufferIfAbsent(AnomalyDetector detector, String detectorId) { + CacheBuffer buffer = activeEnities.get(detectorId); + if (buffer == null) { + long requiredBytes = getRequiredMemory(detector, dedicatedCacheSize); + if (memoryTracker.canAllocateReserved(requiredBytes)) { + memoryTracker.consumeMemory(requiredBytes, true, Origin.HC_DETECTOR); + long intervalSecs = detector.getIntervalInSeconds(); + + buffer = new CacheBuffer( + dedicatedCacheSize, + intervalSecs, + getRequiredMemory(detector, 1), + memoryTracker, + clock, + modelTtl, + detectorId, + checkpointWriteQueue, + checkpointMaintainQueue, + checkpointIntervalHrs + ); + activeEnities.put(detectorId, buffer); + // There can be race conditions between tryClearUpMemory and + // activeEntities.put above as tryClearUpMemory accesses activeEnities too. + // Put tryClearUpMemory after consumeMemory to prevent that. 
+ tryClearUpMemory(); + } else { + throw new LimitExceededException(detectorId, CommonMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG); + } + + } + return buffer; + } + + /** + * + * @param detector Detector config accessor + * @param numberOfEntity number of entities + * @return Memory in bytes required for hosting numberOfEntity entities + */ + private long getRequiredMemory(AnomalyDetector detector, int numberOfEntity) { + int dimension = detector.getEnabledFeatureIds().size() * detector.getShingleSize(); + return numberOfEntity * memoryTracker + .estimateTRCFModelSize( + dimension, + numberOfTrees, + TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO, + detector.getShingleSize().intValue(), + true + ); + } + + /** + * Whether the candidate entity can replace any entity in the shared cache. + * We can have race conditions when multiple threads try to evaluate this + * function. The result is that multiple threads may think they + * can replace entities in the cache. + * + * + * @param originBuffer the CacheBuffer that the entity belongs to (with the same detector Id) + * @param candidatePriority the candidate entity's priority + * @return the CacheBuffer if we can find a CacheBuffer to make room for the candidate entity + */ + private Triple canReplaceInSharedCache(CacheBuffer originBuffer, float candidatePriority) { + CacheBuffer minPriorityBuffer = null; + float minPriority = candidatePriority; + String minPriorityEntityModelId = null; + for (Map.Entry entry : activeEnities.entrySet()) { + CacheBuffer buffer = entry.getValue(); + if (buffer != originBuffer && buffer.canRemove()) { + Optional> priorityEntry = buffer.getPriorityTracker().getMinimumScaledPriority(); + if (!priorityEntry.isPresent()) { + continue; + } + float priority = priorityEntry.get().getValue(); + if (candidatePriority > priority && priority < minPriority) { + minPriority = priority; + minPriorityBuffer = buffer; + minPriorityEntityModelId = priorityEntry.get().getKey(); + } + } + } + return Triple.of(minPriorityBuffer, minPriorityEntityModelId, minPriority); + } + + /** + * Clear up overused memory. Overuse can happen due to race conditions or because other detectors + * consume resources from shared memory. + * tryClearUpMemory is run on the AD thread pool because the function is expensive. + */ + private void tryClearUpMemory() { + try { + if (maintenanceLock.tryLock()) { + threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME).execute(() -> clearMemory()); + } else { + threadPool.schedule(() -> { + try { + tryClearUpMemory(); + } catch (Exception e) { + LOG.error("Fail to clear up memory taken by CacheBuffer. 
Will retry during maintenance."); + } + }, new TimeValue(random.nextInt(90), TimeUnit.SECONDS), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); + } + } finally { + if (maintenanceLock.isHeldByCurrentThread()) { + maintenanceLock.unlock(); + } + } + } + + private void clearMemory() { + recalculateUsedMemory(); + long memoryToShed = memoryTracker.memoryToShed(); + PriorityQueue> removalCandiates = null; + if (memoryToShed > 0) { + // sort the triple in an ascending order of priority + removalCandiates = new PriorityQueue<>((x, y) -> Float.compare(x.getLeft(), y.getLeft())); + for (Map.Entry entry : activeEnities.entrySet()) { + CacheBuffer buffer = entry.getValue(); + Optional> priorityEntry = buffer.getPriorityTracker().getMinimumScaledPriority(); + if (!priorityEntry.isPresent()) { + continue; + } + float priority = priorityEntry.get().getValue(); + if (buffer.canRemove()) { + removalCandiates.add(Triple.of(priority, buffer, priorityEntry.get().getKey())); + } + } + } + while (memoryToShed > 0) { + if (false == removalCandiates.isEmpty()) { + Triple toRemove = removalCandiates.poll(); + CacheBuffer minPriorityBuffer = toRemove.getMiddle(); + String minPriorityEntityModelId = toRemove.getRight(); + + ModelState removed = minPriorityBuffer.remove(minPriorityEntityModelId); + memoryToShed -= minPriorityBuffer.getMemoryConsumptionPerEntity(); + addIntoInactiveCache(removed); + + if (minPriorityBuffer.canRemove()) { + // can remove another one + Optional> priorityEntry = minPriorityBuffer.getPriorityTracker().getMinimumScaledPriority(); + if (priorityEntry.isPresent()) { + removalCandiates.add(Triple.of(priorityEntry.get().getValue(), minPriorityBuffer, priorityEntry.get().getKey())); + } + } + } + + if (removalCandiates.isEmpty()) { + break; + } + } + + } + + /** + * Recalculate memory consumption in case of bugs/race conditions when allocating/releasing memory + */ + private void recalculateUsedMemory() { + long reserved = 0; + long shared = 0; + for (Map.Entry entry : activeEnities.entrySet()) { + CacheBuffer buffer = entry.getValue(); + reserved += buffer.getReservedBytes(); + shared += buffer.getBytesInSharedCache(); + } + memoryTracker.syncMemoryState(Origin.HC_DETECTOR, reserved + shared, reserved); + } + + /** + * Maintain active entity's cache and door keepers. + * + * inActiveEntities is a Guava's LRU cache. 
The data structure itself is + * gonna evict items if they are inactive for 3 days or its maximum size + * reached (1 million entries) + */ + @Override + public void maintenance() { + try { + // clean up memory if we allocate more memory than we should + tryClearUpMemory(); + activeEnities.entrySet().stream().forEach(cacheBufferEntry -> { + String detectorId = cacheBufferEntry.getKey(); + CacheBuffer cacheBuffer = cacheBufferEntry.getValue(); + // remove expired cache buffer + if (cacheBuffer.expired(modelTtl)) { + activeEnities.remove(detectorId); + cacheBuffer.clear(); + } else { + List> removedStates = cacheBuffer.maintenance(); + for (ModelState state : removedStates) { + addIntoInactiveCache(state); + } + } + }); + + maintainInactiveCache(); + + doorKeepers.entrySet().stream().forEach(doorKeeperEntry -> { + String detectorId = doorKeeperEntry.getKey(); + DoorKeeper doorKeeper = doorKeeperEntry.getValue(); + // doorKeeper has its own state ttl + if (doorKeeper.expired(null)) { + doorKeepers.remove(detectorId); + } else { + doorKeeper.maintenance(); + } + }); + } catch (Exception e) { + // will be thrown to ES's transport broadcast handler + throw new TimeSeriesException("Fail to maintain cache", e); + } + + } + + /** + * Permanently deletes models hosted in memory and persisted in index. + * + * @param detectorId id the of the detector for which models are to be permanently deleted + */ + @Override + public void clear(String detectorId) { + if (Strings.isEmpty(detectorId)) { + return; + } + CacheBuffer buffer = activeEnities.remove(detectorId); + if (buffer != null) { + buffer.clear(); + } + checkpointDao.deleteModelCheckpointByDetectorId(detectorId); + doorKeepers.remove(detectorId); + } + + /** + * Get the number of active entities of a detector + * @param detectorId Detector Id + * @return The number of active entities + */ + @Override + public int getActiveEntities(String detectorId) { + CacheBuffer cacheBuffer = activeEnities.get(detectorId); + if (cacheBuffer != null) { + return cacheBuffer.getActiveEntities(); + } + return 0; + } + + /** + * Whether an entity is active or not + * @param detectorId The Id of the detector that an entity belongs to + * @param entityModelId Entity's Model Id + * @return Whether an entity is active or not + */ + @Override + public boolean isActive(String detectorId, String entityModelId) { + CacheBuffer cacheBuffer = activeEnities.get(detectorId); + if (cacheBuffer != null) { + return cacheBuffer.isActive(entityModelId); + } + return false; + } + + @Override + public long getTotalUpdates(String detectorId) { + return Optional + .of(activeEnities) + .map(entities -> entities.get(detectorId)) + .map(buffer -> buffer.getPriorityTracker().getHighestPriorityEntityId()) + .map(entityModelIdOptional -> entityModelIdOptional.get()) + .map(entityModelId -> getTotalUpdates(detectorId, entityModelId)) + .orElse(0L); + } + + @Override + public long getTotalUpdates(String detectorId, String entityModelId) { + CacheBuffer cacheBuffer = activeEnities.get(detectorId); + if (cacheBuffer != null) { + Optional modelOptional = cacheBuffer.getModel(entityModelId); + // TODO: make it work for shingles. 
samples.size() is not the real shingle + long accumulatedShingles = modelOptional + .flatMap(model -> model.getTrcf()) + .map(trcf -> trcf.getForest()) + .map(rcf -> rcf.getTotalUpdates()) + .orElseGet( + () -> modelOptional.map(model -> model.getSamples()).map(samples -> samples.size()).map(Long::valueOf).orElse(0L) + ); + return accumulatedShingles; + } + return 0L; + } + + /** + * + * @return total active entities in the cache + */ + @Override + public int getTotalActiveEntities() { + AtomicInteger total = new AtomicInteger(); + activeEnities.values().stream().forEach(cacheBuffer -> { total.addAndGet(cacheBuffer.getActiveEntities()); }); + return total.get(); + } + + /** + * Gets modelStates of all model hosted on a node + * + * @return list of modelStates + */ + @Override + public List> getAllModels() { + List> states = new ArrayList<>(); + activeEnities.values().stream().forEach(cacheBuffer -> states.addAll(cacheBuffer.getAllModels())); + return states; + } + + /** + * Gets all of a detector's model sizes hosted on a node + * + * @return a map of model id to its memory size + */ + @Override + public Map getModelSize(String detectorId) { + CacheBuffer cacheBuffer = activeEnities.get(detectorId); + Map res = new HashMap<>(); + if (cacheBuffer != null) { + long size = cacheBuffer.getMemoryConsumptionPerEntity(); + cacheBuffer.getAllModels().forEach(entry -> res.put(entry.getModelId(), size)); + } + return res; + } + + /** + * Return the last active time of an entity's state. + * + * If the entity's state is active in the cache, the value indicates when the cache + * is lastly accessed (get/put). If the entity's state is inactive in the cache, + * the value indicates when the cache state is created or when the entity is evicted + * from active entity cache. + * + * @param detectorId The Id of the detector that an entity belongs to + * @param entityModelId Entity's Model Id + * @return if the entity is in the cache, return the timestamp in epoch + * milliseconds when the entity's state is lastly used. Otherwise, return -1. + */ + @Override + public long getLastActiveMs(String detectorId, String entityModelId) { + CacheBuffer cacheBuffer = activeEnities.get(detectorId); + long lastUsedMs = -1; + if (cacheBuffer != null) { + lastUsedMs = cacheBuffer.getLastUsedTime(entityModelId); + if (lastUsedMs != -1) { + return lastUsedMs; + } + } + ModelState stateInActive = inActiveEntities.getIfPresent(entityModelId); + if (stateInActive != null) { + lastUsedMs = stateInActive.getLastUsedTime().toEpochMilli(); + } + return lastUsedMs; + } + + @Override + public void releaseMemoryForOpenCircuitBreaker() { + maintainInactiveCache(); + + tryClearUpMemory(); + activeEnities.values().stream().forEach(cacheBuffer -> { + if (cacheBuffer.canRemove()) { + ModelState removed = cacheBuffer.remove(); + addIntoInactiveCache(removed); + } + }); + } + + private void maintainInactiveCache() { + if (lastInActiveEntityMaintenance.plus(this.modelTtl).isAfter(clock.instant())) { + // don't scan inactive cache too frequently as it is costly + return; + } + + // force maintenance of the cache. ref: https://tinyurl.com/pyy3p9v6 + inActiveEntities.cleanUp(); + + // // make sure no model has been stored due to bugs + for (ModelState state : inActiveEntities.asMap().values()) { + EntityModel model = state.getModel(); + if (model != null && model.getTrcf().isPresent()) { + LOG.warn(new ParameterizedMessage("Inactive entity's model is null: [{}]. 
Maybe there are bugs.", state.getModelId())); + state.setModel(null); + } + } + + lastInActiveEntityMaintenance = clock.instant(); + } + + /** + * Called when dedicated cache size changes. Will adjust existing cache buffer's + * cache size + */ + private void setDedicatedCacheSizeListener() { + activeEnities.values().stream().forEach(cacheBuffer -> cacheBuffer.setMinimumCapacity(dedicatedCacheSize)); + } + + private void setCheckpointFreqListener() { + activeEnities.values().stream().forEach(cacheBuffer -> cacheBuffer.setCheckpointIntervalHrs(checkpointIntervalHrs)); + } + + @Override + public List getAllModelProfile(String detectorId) { + CacheBuffer cacheBuffer = activeEnities.get(detectorId); + List res = new ArrayList<>(); + if (cacheBuffer != null) { + long size = cacheBuffer.getMemoryConsumptionPerEntity(); + cacheBuffer.getAllModels().forEach(entry -> { + EntityModel model = entry.getModel(); + Entity entity = null; + if (model != null && model.getEntity().isPresent()) { + entity = model.getEntity().get(); + } + res.add(new ModelProfile(entry.getModelId(), entity, size)); + }); + } + return res; + } + + /** + * Gets an entity's model state + * + * @param detectorId detector id + * @param entityModelId entity model id + * @return the model state + */ + @Override + public Optional getModelProfile(String detectorId, String entityModelId) { + CacheBuffer cacheBuffer = activeEnities.get(detectorId); + if (cacheBuffer != null && cacheBuffer.getModel(entityModelId).isPresent()) { + EntityModel model = cacheBuffer.getModel(entityModelId).get(); + Entity entity = null; + if (model != null && model.getEntity().isPresent()) { + entity = model.getEntity().get(); + } + return Optional.of(new ModelProfile(entityModelId, entity, cacheBuffer.getMemoryConsumptionPerEntity())); + } + return Optional.empty(); + } + + /** + * Throw an IllegalArgumentException even the dedicated size increases cannot + * be fulfilled. + * + * @param newDedicatedCacheSize the new dedicated cache size to validate + */ + private void validateDedicatedCacheSize(Integer newDedicatedCacheSize) { + if (this.dedicatedCacheSize < newDedicatedCacheSize) { + int delta = newDedicatedCacheSize - this.dedicatedCacheSize; + long totalIncreasedBytes = 0; + for (CacheBuffer cacheBuffer : activeEnities.values()) { + totalIncreasedBytes += cacheBuffer.getMemoryConsumptionPerEntity() * delta; + } + + if (false == memoryTracker.canAllocateReserved(totalIncreasedBytes)) { + throw new IllegalArgumentException("We don't have enough memory for the required change"); + } + } + } + + /** + * Get a model state without incurring priority update. Used in maintenance. + * @param detectorId Detector Id + * @param modelId Model Id + * @return Model state + */ + @Override + public Optional> getForMaintainance(String detectorId, String modelId) { + CacheBuffer buffer = activeEnities.get(detectorId); + if (buffer == null) { + return Optional.empty(); + } + return Optional.ofNullable(buffer.getWithoutUpdatePriority(modelId)); + } + + /** + * Remove entity model from active entity buffer and delete checkpoint. Used to clean corrupted model. 
+ * @param detectorId Detector Id + * @param entityModelId Model Id + */ + @Override + public void removeEntityModel(String detectorId, String entityModelId) { + CacheBuffer buffer = activeEnities.get(detectorId); + if (buffer != null) { + ModelState removed = null; + if ((removed = buffer.remove(entityModelId, false)) != null) { + addIntoInactiveCache(removed); + } + } + checkpointDao + .deleteModelCheckpoint( + entityModelId, + ActionListener + .wrap( + r -> LOG.debug(new ParameterizedMessage("Succeeded in deleting checkpoint [{}].", entityModelId)), + e -> LOG.error(new ParameterizedMessage("Failed to delete checkpoint [{}].", entityModelId), e) + ) + ); + } + + private Cache> createInactiveCache(Duration inactiveEntityTtl, int maxInactiveStates) { + return CacheBuilder + .newBuilder() + .expireAfterAccess(inactiveEntityTtl.toHours(), TimeUnit.HOURS) + .maximumSize(maxInactiveStates) + .concurrencyLevel(1) + .build(); + } +} diff --git a/src/main/java/org/opensearch/ad/caching/PriorityTracker.java-e b/src/main/java/org/opensearch/ad/caching/PriorityTracker.java-e new file mode 100644 index 000000000..439d67679 --- /dev/null +++ b/src/main/java/org/opensearch/ad/caching/PriorityTracker.java-e @@ -0,0 +1,359 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.caching; + +import java.time.Clock; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentSkipListSet; + +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.commons.lang.builder.ToStringBuilder; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.timeseries.annotation.Generated; + +/** + * A priority tracker for entities. Read docs/entity-priority.pdf for details. + * + * HC detectors use a 1-pass algorithm for estimating heavy hitters in a stream. + * Our method maintains a time-decayed count for each entity, which allows us to + * compare the frequencies/priorities of entities from different detectors in the + * stream. + * This class contains the heavy-hitter tracking logic.  When an entity is hit, + * a user calls PriorityTracker.updatePriority to update the entity's priority. + * The user can find the most frequently occurring entities in the stream using + * PriorityTracker.getTopNEntities. A typical usage is listed below: + * + *
+ * PriorityTracker tracker =  ...
+ *
+ * // at time t1
+ * tracker.updatePriority(entity1);
+ * tracker.updatePriority(entity3);
+ *
+ * //  at time t2
+ * tracker.updatePriority(entity1);
+ * tracker.updatePriority(entity2);
+ *
+ * // we should have entity 1, 2, 3 in that order; entity 2 comes before entity 3 because its hit happened later
+ * List<String> top3 = tracker.getTopNEntities(3);
+ * 
+ * + */ +public class PriorityTracker { + private static final Logger LOG = LogManager.getLogger(PriorityTracker.class); + + // data structure for an entity and its priority + static class PriorityNode { + // entity key + private String key; + // time-decayed priority + private float priority; + + PriorityNode(String key, float priority) { + this.priority = priority; + this.key = key; + } + + @Generated + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + if (obj instanceof PriorityNode) { + PriorityNode other = (PriorityNode) obj; + + EqualsBuilder equalsBuilder = new EqualsBuilder(); + equalsBuilder.append(key, other.key); + return equalsBuilder.isEquals(); + } + return false; + } + + @Generated + @Override + public int hashCode() { + return new HashCodeBuilder().append(key).toHashCode(); + } + + @Generated + @Override + public String toString() { + ToStringBuilder builder = new ToStringBuilder(this); + builder.append("key", key); + builder.append("priority", priority); + return builder.toString(); + } + } + + // Comparator between two entities. Used to sort entities in a priority queue + static class PriorityNodeComparator implements Comparator { + + @Override + public int compare(PriorityNode priority, PriorityNode priority2) { + int equality = priority.key.compareTo(priority2.key); + if (equality == 0) { + // this is consistent with PriorityNode's equals method + return 0; + } + // if not equal, first check priority + int cmp = Float.compare(priority.priority, priority2.priority); + if (cmp == 0) { + // if priority is equal, use lexicographical order of key + cmp = equality; + } + return cmp; + } + } + + // key -> Priority node + private final ConcurrentHashMap key2Priority; + // when detector is created.  Can be reset.  Unit: seconds + private long landmarkEpoch; + // a list of priority nodes + private final ConcurrentSkipListSet priorityList; + // Used to get current time. + private final Clock clock; + // length of seconds in one interval.  Used to compute elapsed periods + // since the detector has been enabled. + private final long intervalSecs; + // determines how fast the decay is + // We use the decay constant 0.125. The half life (https://en.wikipedia.org/wiki/Exponential_decay) + // is 8* ln(2). This means the old value falls to one half with roughly 5.6 intervals. + // We chose 0.125 because multiplying 0.125 can be implemented efficiently using 3 right + // shift and the half life is not too fast or slow . + private final int DECAY_CONSTANT; + // the max number of entities to track + private final int maxEntities; + + /** + * Create a priority tracker for a detector. Detector and priority tracker + * have 1:1 mapping. + * + * @param clock Used to get current time. + * @param intervalSecs Detector interval seconds. + * @param landmarkEpoch The epoch time when the priority tracking starts. + * @param maxEntities the max number of entities to track + */ + public PriorityTracker(Clock clock, long intervalSecs, long landmarkEpoch, int maxEntities) { + this.key2Priority = new ConcurrentHashMap<>(); + this.clock = clock; + this.intervalSecs = intervalSecs; + this.landmarkEpoch = landmarkEpoch; + this.priorityList = new ConcurrentSkipListSet<>(new PriorityNodeComparator()); + this.DECAY_CONSTANT = 3; + this.maxEntities = maxEntities; + } + + /** + * Get the minimum priority entity and compute its scaled priority. 
+ * Used to compare entity priorities among detectors. + * @return the minimum priority entity's ID and scaled priority or Optional.empty + * if the priority list is empty + */ + public Optional> getMinimumScaledPriority() { + if (priorityList.isEmpty()) { + return Optional.empty(); + } + PriorityNode smallest = priorityList.first(); + return Optional.of(new SimpleImmutableEntry<>(smallest.key, getScaledPriority(smallest.priority))); + } + + /** + * Get the minimum priority entity and compute its scaled priority. + * Used to compare entity priorities within the same detector. + * @return the minimum priority entity's ID and scaled priority or Optional.empty + * if the priority list is empty + */ + public Optional> getMinimumPriority() { + if (priorityList.isEmpty()) { + return Optional.empty(); + } + PriorityNode smallest = priorityList.first(); + return Optional.of(new SimpleImmutableEntry<>(smallest.key, smallest.priority)); + } + + /** + * + * @return the minimum priority entity's Id or Optional.empty + * if the priority list is empty + */ + public Optional getMinimumPriorityEntityId() { + if (priorityList.isEmpty()) { + return Optional.empty(); + } + return Optional.of(priorityList).map(list -> list.first()).map(node -> node.key); + } + + /** + * + * @return Get maximum priority entity's Id + */ + public Optional getHighestPriorityEntityId() { + if (priorityList.isEmpty()) { + return Optional.empty(); + } + return Optional.of(priorityList).map(list -> list.last()).map(node -> node.key); + } + + /** + * Update an entity's priority with count increment + * @param entityId Entity Id + */ + public void updatePriority(String entityId) { + PriorityNode node = key2Priority.computeIfAbsent(entityId, k -> new PriorityNode(entityId, 0f)); + // reposition this node + this.priorityList.remove(node); + node.priority = getUpdatedPriority(node.priority); + this.priorityList.add(node); + + adjustSizeIfRequired(); + } + + /** + * Associate the specified priority with the entity Id + * @param entityId Entity Id + * @param priority priority + */ + protected void addPriority(String entityId, float priority) { + PriorityNode node = new PriorityNode(entityId, priority); + key2Priority.put(entityId, node); + priorityList.add(node); + + adjustSizeIfRequired(); + } + + /** + * Adjust tracking list if the size exceeded the limit + */ + private void adjustSizeIfRequired() { + if (key2Priority.size() > maxEntities) { + Optional minPriorityId = getMinimumPriorityEntityId(); + if (minPriorityId.isPresent()) { + removePriority(minPriorityId.get()); + } + } + } + + /** + * Remove an entity in the tracker + * @param entityId Entity Id + */ + protected void removePriority(String entityId) { + // remove if the key matches; priority does not matter + priorityList.remove(new PriorityNode(entityId, 0)); + key2Priority.remove(entityId); + } + + /** + * Remove all of entities + */ + protected void clearPriority() { + key2Priority.clear(); + priorityList.clear(); + } + + /** + * Return the updated priority with new priority increment. Used when comparing + * entities' priorities within the same detector. + * + * Each detector maintains an ordered map, filled by entities's accumulated sum of g(i−L), + * which is what this function computes. + * + * g(n) = e^{0.125n}. i is current period. L is the landmark: period 0 when the + * detector is enabled. i - L measures the elapsed periods since detector starts. + * 0.125 is the decay constant. 
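+ * For illustration (editor's note; hypothetical numbers): with the landmark at period 0 and the 0.125
+ * decay, a hit at period 8 contributes an increment of 1 and a hit at period 16 contributes 2, so the
+ * priority tracks roughly log(e^1 + e^2), which is dominated by the most recent hit. The implementation
+ * stores the log of this accumulated sum and updates it incrementally via
+ * oldPriority += log(1 + exp(increment - oldPriority)) so the raw sum never overflows a float.
+ *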
+ * + * Since g(i−L) is changing and they are the same for all entities of the same detector, + * we can compare entities' priorities by considering the accumulated sum of g(i−L). + * + * @param oldPriority Existing priority + * + * @return new priority + */ + float getUpdatedPriority(float oldPriority) { + long increment = computeWeightedPriorityIncrement(); + oldPriority += Math.log(1 + Math.exp(increment - oldPriority)); + // if overflow happens, using the most recent decayed count instead. + if (oldPriority == Float.POSITIVE_INFINITY) { + oldPriority = increment; + } + return oldPriority; + } + + /** + * Return the scaled priority. Used when comparing entities' priorities among + * different detectors. + * + * Updated priority = current priority - log(g(t - L)), where g(n) = e^{0.125n}, + * t is current time, and L is the landmark. t - L measures the number of elapsed + * periods relative to the landmark. + * + * When replacing an entity, we query the minimum from each ordered map and + * compute w(i,p) for each minimum entity by scaling the sum by g(p−L). Notice g(p−L) + * can be different if detectors start at different timestamps. The minimum of the minimum + * is selected to be replaced. The number of multi-entity detectors is limited (we consider + * to support ten currently), so the computation is cheap. + * + * @param currentPriority Current priority + * @return the scaled priority + */ + float getScaledPriority(float currentPriority) { + return currentPriority - computeWeightedPriorityIncrement(); + } + + /** + * Compute the weighted priority increment using 0.125n, where n is the number of + * periods relative to the landmark. + * Each detector has its own landmark L: period 0 when the detector is enabled. + * + * @return the weighted priority increment used in the priority update step. + */ + long computeWeightedPriorityIncrement() { + long periods = (clock.instant().getEpochSecond() - landmarkEpoch) / intervalSecs; + return periods >> DECAY_CONSTANT; + } + + /** + * + * @param n the number of entities to return. Can be less than n if there are not enough entities stored. + * @return top entities in the descending order of priority + */ + public List getTopNEntities(int n) { + List entities = new ArrayList<>(); + Iterator entityIterator = priorityList.descendingIterator(); + for (int i = 0; i < n && entityIterator.hasNext(); i++) { + entities.add(entityIterator.next().key); + } + return entities; + } + + /** + * + * @return the number of tracked entities + */ + public int size() { + return key2Priority.size(); + } +} diff --git a/src/main/java/org/opensearch/ad/cluster/ADClusterEventListener.java-e b/src/main/java/org/opensearch/ad/cluster/ADClusterEventListener.java-e new file mode 100644 index 000000000..fd00d9c22 --- /dev/null +++ b/src/main/java/org/opensearch/ad/cluster/ADClusterEventListener.java-e @@ -0,0 +1,98 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.cluster; + +import java.util.concurrent.Semaphore; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.cluster.ClusterChangedEvent; +import org.opensearch.cluster.ClusterStateListener; +import org.opensearch.cluster.node.DiscoveryNodes.Delta; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.gateway.GatewayService; + +public class ADClusterEventListener implements ClusterStateListener { + private static final Logger LOG = LogManager.getLogger(ADClusterEventListener.class); + static final String NOT_RECOVERED_MSG = "Cluster is not recovered yet."; + static final String IN_PROGRESS_MSG = "Cluster state change in progress, return."; + static final String NODE_CHANGED_MSG = "Cluster node changed"; + + private final Semaphore inProgress; + private HashRing hashRing; + private final ClusterService clusterService; + + @Inject + public ADClusterEventListener(ClusterService clusterService, HashRing hashRing) { + this.clusterService = clusterService; + this.clusterService.addListener(this); + this.hashRing = hashRing; + this.inProgress = new Semaphore(1); + } + + @Override + public void clusterChanged(ClusterChangedEvent event) { + + if (event.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) { + LOG.info(NOT_RECOVERED_MSG); + return; + } + + if (!inProgress.tryAcquire()) { + LOG.info(IN_PROGRESS_MSG); + return; + } + + try { + // Init AD version hash ring as early as possible. Some test case may fail as AD + // version hash ring not initialized when test run. + if (!hashRing.isHashRingInited()) { + hashRing + .buildCircles( + ActionListener + .wrap( + r -> LOG.info("Init AD version hash ring successfully"), + e -> LOG.error("Failed to init AD version hash ring") + ) + ); + } + Delta delta = event.nodesDelta(); + + if (delta.removed() || delta.added()) { + LOG.info(NODE_CHANGED_MSG + ", node removed: {}, node added: {}", delta.removed(), delta.added()); + hashRing.addNodeChangeEvent(); + hashRing + .buildCircles( + delta, + ActionListener + .runAfter( + ActionListener + .wrap( + hasRingBuildDone -> { LOG.info("Hash ring build result: {}", hasRingBuildDone); }, + e -> { LOG.error("Failed updating AD version hash ring", e); } + ), + () -> inProgress.release() + ) + ); + } else { + inProgress.release(); + } + } catch (Exception ex) { + // One possible exception is OpenSearchTimeoutException thrown when we fail + // to put checkpoint when ModelManager stops model. 
+ LOG.error("Cluster state change handler has issue(s)", ex); + inProgress.release(); + } + } +} diff --git a/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java b/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java index 1340a8f8b..62702d15c 100644 --- a/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java +++ b/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java @@ -17,7 +17,7 @@ import static org.opensearch.ad.model.ADTask.TASK_TYPE_FIELD; import static org.opensearch.ad.model.ADTaskType.taskTypeToString; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_DETECTOR_UPPER_LIMIT; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; diff --git a/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java-e b/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java-e new file mode 100644 index 000000000..62702d15c --- /dev/null +++ b/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java-e @@ -0,0 +1,299 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.cluster; + +import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX; +import static org.opensearch.ad.model.ADTask.DETECTOR_ID_FIELD; +import static org.opensearch.ad.model.ADTask.IS_LATEST_FIELD; +import static org.opensearch.ad.model.ADTask.TASK_TYPE_FIELD; +import static org.opensearch.ad.model.ADTaskType.taskTypeToString; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_DETECTOR_UPPER_LIMIT; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; +import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; + +import java.io.IOException; +import java.time.Instant; +import java.util.Iterator; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskState; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.model.DetectorInternalState; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.xcontent.XContentFactory; +import 
org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ExecutorFunction; + +/** + * Migrate AD data to support backward compatibility. + * Currently we need to migrate: + * 1. Detector internal state (used to track realtime job error) to realtime data. + */ +public class ADDataMigrator { + private final Logger logger = LogManager.getLogger(this.getClass()); + private final Client client; + private final ClusterService clusterService; + private final NamedXContentRegistry xContentRegistry; + private final ADIndexManagement detectionIndices; + private final AtomicBoolean dataMigrated; + + public ADDataMigrator( + Client client, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry, + ADIndexManagement detectionIndices + ) { + this.client = client; + this.clusterService = clusterService; + this.xContentRegistry = xContentRegistry; + this.detectionIndices = detectionIndices; + this.dataMigrated = new AtomicBoolean(false); + } + + /** + * Migrate AD data. Currently only need to migrate detector internal state {@link DetectorInternalState} + */ + public void migrateData() { + if (!dataMigrated.getAndSet(true)) { + logger.info("Start migrating AD data"); + + if (!detectionIndices.doesJobIndexExist()) { + logger.info("AD job index doesn't exist, no need to migrate"); + return; + } + + if (detectionIndices.doesStateIndexExist()) { + migrateDetectorInternalStateToRealtimeTask(); + } else { + // If detection index doesn't exist, create index and backfill realtime task. + detectionIndices.initStateIndex(ActionListener.wrap(r -> { + if (r.isAcknowledged()) { + logger.info("Created {} with mappings.", ADCommonName.DETECTION_STATE_INDEX); + migrateDetectorInternalStateToRealtimeTask(); + } else { + String error = "Create index " + ADCommonName.DETECTION_STATE_INDEX + " with mappings not acknowledged"; + logger.warn(error); + } + }, e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) { + // When migrate data, it's possible that user run some historical analysis and it will create detection + // state index. Then we will see ResourceAlreadyExistsException. + migrateDetectorInternalStateToRealtimeTask(); + } else { + logger.error("Failed to init anomaly detection state index", e); + } + })); + } + } + } + + /** + * Migrate detector internal state to realtime task. 
+ */ + public void migrateDetectorInternalStateToRealtimeTask() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(new MatchAllQueryBuilder()) + .size(MAX_DETECTOR_UPPER_LIMIT); + SearchRequest searchRequest = new SearchRequest(CommonName.JOB_INDEX).source(searchSourceBuilder); + client.search(searchRequest, ActionListener.wrap(r -> { + if (r == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { + logger.info("No anomaly detector job found, no need to migrate"); + return; + } + ConcurrentLinkedQueue detectorJobs = new ConcurrentLinkedQueue<>(); + Iterator iterator = r.getHits().iterator(); + while (iterator.hasNext()) { + SearchHit searchHit = iterator.next(); + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser); + detectorJobs.add(job); + } catch (IOException e) { + logger.error("Fail to parse AD job " + searchHit.getId(), e); + } + } + logger.info("Total AD jobs to backfill realtime task: {}", detectorJobs.size()); + backfillRealtimeTask(detectorJobs, true); + }, e -> { + if (ExceptionUtil.getErrorMessage(e).contains("all shards failed")) { + // This error may happen when AD job index not ready for query as some nodes not in cluster yet. + // Will recreate realtime task when AD job starts. + logger.warn("No available shards of AD job index, reset dataMigrated as false"); + this.dataMigrated.set(false); + } else if (!(e instanceof IndexNotFoundException)) { + logger.error("Failed to migrate AD data", e); + } + })); + } + + /** + * Backfill realtiem task for realtime job. + * @param detectorJobs realtime AD jobs + * @param backfillAllJob backfill task for all realtime job or not + */ + public void backfillRealtimeTask(ConcurrentLinkedQueue detectorJobs, boolean backfillAllJob) { + AnomalyDetectorJob job = detectorJobs.poll(); + if (job == null) { + logger.info("AD data migration done."); + if (backfillAllJob) { + this.dataMigrated.set(true); + } + return; + } + String jobId = job.getName(); + + ExecutorFunction createRealtimeTaskFunction = () -> { + GetRequest getRequest = new GetRequest(DETECTION_STATE_INDEX, jobId); + client.get(getRequest, ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + DetectorInternalState detectorState = DetectorInternalState.parse(parser); + createRealtimeADTask(job, detectorState.getError(), detectorJobs, backfillAllJob); + } catch (IOException e) { + logger.error("Failed to parse detector internal state " + jobId, e); + createRealtimeADTask(job, null, detectorJobs, backfillAllJob); + } + } else { + createRealtimeADTask(job, null, detectorJobs, backfillAllJob); + } + }, e -> { + logger.error("Failed to query detector internal state " + jobId, e); + createRealtimeADTask(job, null, detectorJobs, backfillAllJob); + })); + }; + checkIfRealtimeTaskExistsAndBackfill(job, createRealtimeTaskFunction, detectorJobs, backfillAllJob); + } + + private void checkIfRealtimeTaskExistsAndBackfill( + AnomalyDetectorJob job, + ExecutorFunction createRealtimeTaskFunction, + ConcurrentLinkedQueue detectorJobs, + boolean migrateAll + ) { + String jobId = job.getName(); + BoolQueryBuilder query = new 
BoolQueryBuilder(); + query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, jobId)); + if (job.isEnabled()) { + query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true)); + } + + query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(ADTaskType.REALTIME_TASK_TYPES))); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(1); + SearchRequest searchRequest = new SearchRequest(DETECTION_STATE_INDEX).source(searchSourceBuilder); + client.search(searchRequest, ActionListener.wrap(r -> { + if (r != null && r.getHits().getTotalHits().value > 0) { + // Backfill next realtime job + backfillRealtimeTask(detectorJobs, migrateAll); + return; + } + createRealtimeTaskFunction.execute(); + }, e -> { + if (e instanceof ResourceNotFoundException) { + createRealtimeTaskFunction.execute(); + } + logger.error("Failed to search tasks of detector " + jobId); + })); + } + + private void createRealtimeADTask( + AnomalyDetectorJob job, + String error, + ConcurrentLinkedQueue detectorJobs, + boolean migrateAll + ) { + client.get(new GetRequest(CommonName.CONFIG_INDEX, job.getName()), ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetector detector = AnomalyDetector.parse(parser, r.getId()); + ADTaskType taskType = detector.isHighCardinality() + ? ADTaskType.REALTIME_HC_DETECTOR + : ADTaskType.REALTIME_SINGLE_ENTITY; + Instant now = Instant.now(); + String userName = job.getUser() != null ? job.getUser().getName() : null; + ADTask adTask = new ADTask.Builder() + .detectorId(detector.getId()) + .detector(detector) + .error(error) + .isLatest(true) + .taskType(taskType.name()) + .executionStartTime(now) + .taskProgress(0.0f) + .initProgress(0.0f) + .state(ADTaskState.CREATED.name()) + .lastUpdateTime(now) + .startedBy(userName) + .coordinatingNode(null) + .detectionDateRange(null) + .user(job.getUser()) + .build(); + IndexRequest indexRequest = new IndexRequest(DETECTION_STATE_INDEX) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .source(adTask.toXContent(XContentFactory.jsonBuilder(), XCONTENT_WITH_TYPE)); + client.index(indexRequest, ActionListener.wrap(indexResponse -> { + logger.info("Backfill realtime task successfully for detector {}", job.getName()); + backfillRealtimeTask(detectorJobs, migrateAll); + }, ex -> { + logger.error("Failed to backfill realtime task for detector " + job.getName(), ex); + backfillRealtimeTask(detectorJobs, migrateAll); + })); + } catch (IOException e) { + logger.error("Fail to parse detector " + job.getName(), e); + backfillRealtimeTask(detectorJobs, migrateAll); + } + } else { + logger.error("Detector doesn't exist " + job.getName()); + backfillRealtimeTask(detectorJobs, migrateAll); + } + }, e -> { + logger.error("Fail to get detector " + job.getName(), e); + backfillRealtimeTask(detectorJobs, migrateAll); + })); + } + + public void skipMigration() { + this.dataMigrated.set(true); + } + + public boolean isMigrated() { + return this.dataMigrated.get(); + } +} diff --git a/src/main/java/org/opensearch/ad/cluster/ADNodeInfo.java-e b/src/main/java/org/opensearch/ad/cluster/ADNodeInfo.java-e new file mode 100644 index 000000000..e438623d5 --- /dev/null +++ b/src/main/java/org/opensearch/ad/cluster/ADNodeInfo.java-e @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors 
require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.cluster; + +import org.opensearch.Version; + +/** + * This class records AD version of nodes and whether node is eligible data node to run AD. + */ +public class ADNodeInfo { + // AD plugin version + private Version adVersion; + // Is node eligible to run AD. + private boolean isEligibleDataNode; + + public ADNodeInfo(Version version, boolean isEligibleDataNode) { + this.adVersion = version; + this.isEligibleDataNode = isEligibleDataNode; + } + + public Version getAdVersion() { + return adVersion; + } + + public boolean isEligibleDataNode() { + return isEligibleDataNode; + } + + @Override + public String toString() { + return "ADNodeInfo{" + "version=" + adVersion + ", isEligibleDataNode=" + isEligibleDataNode + '}'; + } +} diff --git a/src/main/java/org/opensearch/ad/cluster/ADVersionUtil.java b/src/main/java/org/opensearch/ad/cluster/ADVersionUtil.java index 463057991..7e880de66 100644 --- a/src/main/java/org/opensearch/ad/cluster/ADVersionUtil.java +++ b/src/main/java/org/opensearch/ad/cluster/ADVersionUtil.java @@ -11,16 +11,15 @@ package org.opensearch.ad.cluster; -import static org.opensearch.ad.constant.ADCommonName.AD_PLUGIN_VERSION_FOR_TEST; - import org.opensearch.Version; +import org.opensearch.timeseries.constant.CommonName; public class ADVersionUtil { public static final int VERSION_SEGMENTS = 3; public static Version fromString(String adVersion) { - if (AD_PLUGIN_VERSION_FOR_TEST.equals(adVersion)) { + if (CommonName.TIME_SERIES_PLUGIN_VERSION_FOR_TEST.equals(adVersion)) { return Version.CURRENT; } return Version.fromString(normalizeVersion(adVersion)); diff --git a/src/main/java/org/opensearch/ad/cluster/ADVersionUtil.java-e b/src/main/java/org/opensearch/ad/cluster/ADVersionUtil.java-e new file mode 100644 index 000000000..7e880de66 --- /dev/null +++ b/src/main/java/org/opensearch/ad/cluster/ADVersionUtil.java-e @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.cluster; + +import org.opensearch.Version; +import org.opensearch.timeseries.constant.CommonName; + +public class ADVersionUtil { + + public static final int VERSION_SEGMENTS = 3; + + public static Version fromString(String adVersion) { + if (CommonName.TIME_SERIES_PLUGIN_VERSION_FOR_TEST.equals(adVersion)) { + return Version.CURRENT; + } + return Version.fromString(normalizeVersion(adVersion)); + } + + public static String normalizeVersion(String adVersion) { + if (adVersion == null) { + throw new IllegalArgumentException("AD version is null"); + } + String[] versions = adVersion.split("\\."); + if (versions.length < VERSION_SEGMENTS) { + throw new IllegalArgumentException("Wrong AD version " + adVersion); + } + StringBuilder normalizedVersion = new StringBuilder(); + normalizedVersion.append(versions[0]); + for (int i = 1; i < VERSION_SEGMENTS; i++) { + normalizedVersion.append("."); + normalizedVersion.append(versions[i]); + } + return normalizedVersion.toString(); + } +} diff --git a/src/main/java/org/opensearch/ad/cluster/ClusterManagerEventListener.java-e b/src/main/java/org/opensearch/ad/cluster/ClusterManagerEventListener.java-e new file mode 100644 index 000000000..9cf1dd905 --- /dev/null +++ b/src/main/java/org/opensearch/ad/cluster/ClusterManagerEventListener.java-e @@ -0,0 +1,136 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.cluster; + +import java.time.Clock; +import java.time.Duration; + +import org.opensearch.ad.cluster.diskcleanup.IndexCleanup; +import org.opensearch.ad.cluster.diskcleanup.ModelCheckpointIndexRetention; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.DateUtils; +import org.opensearch.client.Client; +import org.opensearch.cluster.LocalNodeClusterManagerListener; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.component.LifecycleListener; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.threadpool.Scheduler.Cancellable; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +import com.google.common.annotations.VisibleForTesting; + +public class ClusterManagerEventListener implements LocalNodeClusterManagerListener { + + private Cancellable checkpointIndexRetentionCron; + private Cancellable hourlyCron; + private ClusterService clusterService; + private ThreadPool threadPool; + private Client client; + private Clock clock; + private ClientUtil clientUtil; + private DiscoveryNodeFilterer nodeFilter; + private Duration checkpointTtlDuration; + + public ClusterManagerEventListener( + ClusterService clusterService, + ThreadPool threadPool, + Client client, + Clock clock, + ClientUtil clientUtil, + DiscoveryNodeFilterer nodeFilter, + Setting checkpointTtl, + Settings settings + ) { + this.clusterService = clusterService; + this.threadPool = threadPool; + this.client = client; + this.clusterService.addLocalNodeClusterManagerListener(this); + this.clock = clock; + this.clientUtil = clientUtil; + this.nodeFilter = nodeFilter; + + this.checkpointTtlDuration = DateUtils.toDuration(checkpointTtl.get(settings)); + + 
clusterService.getClusterSettings().addSettingsUpdateConsumer(checkpointTtl, it -> { + this.checkpointTtlDuration = DateUtils.toDuration(it); + cancel(checkpointIndexRetentionCron); + IndexCleanup indexCleanup = new IndexCleanup(client, clientUtil, clusterService); + checkpointIndexRetentionCron = threadPool + .scheduleWithFixedDelay( + new ModelCheckpointIndexRetention(checkpointTtlDuration, clock, indexCleanup), + TimeValue.timeValueHours(24), + executorName() + ); + }); + } + + @Override + public void onClusterManager() { + if (hourlyCron == null) { + hourlyCron = threadPool.scheduleWithFixedDelay(new HourlyCron(client, nodeFilter), TimeValue.timeValueHours(1), executorName()); + clusterService.addLifecycleListener(new LifecycleListener() { + @Override + public void beforeStop() { + cancel(hourlyCron); + hourlyCron = null; + } + }); + } + + if (checkpointIndexRetentionCron == null) { + IndexCleanup indexCleanup = new IndexCleanup(client, clientUtil, clusterService); + checkpointIndexRetentionCron = threadPool + .scheduleWithFixedDelay( + new ModelCheckpointIndexRetention(checkpointTtlDuration, clock, indexCleanup), + TimeValue.timeValueHours(24), + executorName() + ); + clusterService.addLifecycleListener(new LifecycleListener() { + @Override + public void beforeStop() { + cancel(checkpointIndexRetentionCron); + checkpointIndexRetentionCron = null; + } + }); + } + } + + @Override + public void offClusterManager() { + cancel(hourlyCron); + cancel(checkpointIndexRetentionCron); + hourlyCron = null; + checkpointIndexRetentionCron = null; + } + + private void cancel(Cancellable cron) { + if (cron != null) { + cron.cancel(); + } + } + + @VisibleForTesting + protected Cancellable getCheckpointIndexRetentionCron() { + return checkpointIndexRetentionCron; + } + + protected Cancellable getHourlyCron() { + return hourlyCron; + } + + private String executorName() { + return ThreadPool.Names.GENERIC; + } +} diff --git a/src/main/java/org/opensearch/ad/cluster/DailyCron.java-e b/src/main/java/org/opensearch/ad/cluster/DailyCron.java-e new file mode 100644 index 000000000..e2b2b8808 --- /dev/null +++ b/src/main/java/org/opensearch/ad/cluster/DailyCron.java-e @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.cluster; + +import java.time.Clock; +import java.time.Duration; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.timeseries.constant.CommonName; + +@Deprecated +public class DailyCron implements Runnable { + private static final Logger LOG = LogManager.getLogger(DailyCron.class); + protected static final String FIELD_MODEL = "queue"; + static final String CANNOT_DELETE_OLD_CHECKPOINT_MSG = "Cannot delete old checkpoint."; + static final String CHECKPOINT_NOT_EXIST_MSG = "Checkpoint index does not exist."; + static final String CHECKPOINT_DELETED_MSG = "checkpoint docs get deleted"; + + private final Clock clock; + private final Duration checkpointTtl; + private final ClientUtil clientUtil; + + public DailyCron(Clock clock, Duration checkpointTtl, ClientUtil clientUtil) { + this.clock = clock; + this.clientUtil = clientUtil; + this.checkpointTtl = checkpointTtl; + } + + @Override + public void run() { + DeleteByQueryRequest deleteRequest = new DeleteByQueryRequest(ADCommonName.CHECKPOINT_INDEX_NAME) + .setQuery( + QueryBuilders + .boolQuery() + .filter( + QueryBuilders + .rangeQuery(CommonName.TIMESTAMP) + .lte(clock.millis() - checkpointTtl.toMillis()) + .format(ADCommonName.EPOCH_MILLIS_FORMAT) + ) + ) + .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN); + clientUtil + .execute( + DeleteByQueryAction.INSTANCE, + deleteRequest, + ActionListener + .wrap( + response -> { + // if 0 docs get deleted, it means our query cannot find any matching doc + LOG.info("{} " + CHECKPOINT_DELETED_MSG, response.getDeleted()); + }, + exception -> { + if (exception instanceof IndexNotFoundException) { + LOG.info(CHECKPOINT_NOT_EXIST_MSG); + } else { + // Gonna eventually delete in maintenance window. 
+ LOG.error(CANNOT_DELETE_OLD_CHECKPOINT_MSG, exception); + } + } + ) + ); + } + +} diff --git a/src/main/java/org/opensearch/ad/cluster/HashRing.java b/src/main/java/org/opensearch/ad/cluster/HashRing.java index edad3299b..3e6ba0b37 100644 --- a/src/main/java/org/opensearch/ad/cluster/HashRing.java +++ b/src/main/java/org/opensearch/ad/cluster/HashRing.java @@ -11,8 +11,6 @@ package org.opensearch.ad.cluster; -import static org.opensearch.ad.constant.ADCommonName.AD_PLUGIN_NAME; -import static org.opensearch.ad.constant.ADCommonName.AD_PLUGIN_NAME_FOR_TEST; import static org.opensearch.ad.settings.AnomalyDetectorSettings.COOLDOWN_MINUTES; import java.time.Clock; @@ -53,6 +51,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.plugins.PluginInfo; import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import com.google.common.collect.Sets; @@ -269,7 +268,8 @@ private void buildCircles(Set removedNodeIds, Set addedNodeIds, } TreeMap circle = null; for (PluginInfo pluginInfo : plugins.getPluginInfos()) { - if (AD_PLUGIN_NAME.equals(pluginInfo.getName()) || AD_PLUGIN_NAME_FOR_TEST.equals(pluginInfo.getName())) { + if (CommonName.TIME_SERIES_PLUGIN_NAME.equals(pluginInfo.getName()) + || CommonName.TIME_SERIES_PLUGIN_NAME_FOR_TEST.equals(pluginInfo.getName())) { Version version = ADVersionUtil.fromString(pluginInfo.getVersion()); boolean eligibleNode = nodeFilter.isEligibleNode(curNode); if (eligibleNode) { diff --git a/src/main/java/org/opensearch/ad/cluster/HashRing.java-e b/src/main/java/org/opensearch/ad/cluster/HashRing.java-e new file mode 100644 index 000000000..3e6ba0b37 --- /dev/null +++ b/src/main/java/org/opensearch/ad/cluster/HashRing.java-e @@ -0,0 +1,596 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.cluster; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.COOLDOWN_MINUTES; + +import java.time.Clock; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.cluster.node.info.NodeInfo; +import org.opensearch.action.admin.cluster.node.info.NodesInfoRequest; +import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.SingleStreamModelIdMapper; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.cluster.ClusterChangedEvent; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.routing.Murmur3HashFunction; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.plugins.PluginInfo; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +import com.google.common.collect.Sets; + +public class HashRing { + private static final Logger LOG = LogManager.getLogger(HashRing.class); + // In case of frequent node join/leave, hash ring has a cooldown period say 5 minute. + // Hash ring doesn't respond to more than 1 cluster membership changes within the + // cool-down period. + static final String COOLDOWN_MSG = "Hash ring doesn't respond to cluster state change within the cooldown period."; + private static final String DEFAULT_HASH_RING_MODEL_ID = "DEFAULT_HASHRING_MODEL_ID"; + static final String REMOVE_MODEL_MSG = "Remove model"; + + private final int VIRTUAL_NODE_COUNT = 100; + + // Semaphore to control only 1 thread can build AD hash ring. + private Semaphore buildHashRingSemaphore; + // This field is to track AD version of all nodes. + // Key: node id; Value: AD node info + private Map nodeAdVersions; + // This field records AD version hash ring in realtime way. Historical detection will use this hash ring. + // Key: AD version; Value: hash ring which only contains eligible data nodes + private TreeMap> circles; + // Track if hash ring inited or not. If not inited, the first clusterManager event will try to init it. + private AtomicBoolean hashRingInited; + + // the UTC epoch milliseconds of the most recent successful update of AD circles for realtime AD. + private long lastUpdateForRealtimeAD; + // Cool down period before next hash ring rebuild. We need this as realtime AD needs stable hash ring. + private volatile TimeValue coolDownPeriodForRealtimeAD; + // This field records AD version hash ring with cooldown period. Realtime job will use this hash ring. 
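The circles declared here are plain consistent-hash rings: each physical node is inserted many times under different virtual-node keys so model ids spread roughly evenly, and removing a node only reshuffles the keys that mapped to it. A toy sketch of that construction follows, using String.hashCode() where the plugin uses Murmur3; all names are illustrative.

import java.util.TreeMap;

final class RingConstructionSketch {
    private static final int VIRTUAL_NODE_COUNT = 100;
    private final TreeMap<Integer, String> circle = new TreeMap<>();

    // One physical node occupies VIRTUAL_NODE_COUNT positions on the ring.
    void addNode(String nodeId) {
        for (int i = 0; i < VIRTUAL_NODE_COUNT; i++) {
            circle.put((nodeId + i).hashCode(), nodeId);
        }
    }

    // Dropping a node removes only its own positions; every other node keeps its slots.
    void removeNode(String nodeId) {
        circle.values().removeIf(nodeId::equals);
    }
}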
+ // Key: AD version; Value: hash ring which only contains eligible data nodes + private TreeMap> circlesForRealtimeAD; + + // Record node change event. Will check if there is node change event when rebuild AD hash ring with + // cooldown for realtime job. + private ConcurrentLinkedQueue nodeChangeEvents; + + private final DiscoveryNodeFilterer nodeFilter; + private final ClusterService clusterService; + private final ADDataMigrator dataMigrator; + private final Clock clock; + private final Client client; + private final ModelManager modelManager; + + public HashRing( + DiscoveryNodeFilterer nodeFilter, + Clock clock, + Settings settings, + Client client, + ClusterService clusterService, + ADDataMigrator dataMigrator, + ModelManager modelManager + ) { + this.nodeFilter = nodeFilter; + this.buildHashRingSemaphore = new Semaphore(1); + this.clock = clock; + this.coolDownPeriodForRealtimeAD = COOLDOWN_MINUTES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(COOLDOWN_MINUTES, it -> coolDownPeriodForRealtimeAD = it); + + this.lastUpdateForRealtimeAD = 0; + this.client = client; + this.clusterService = clusterService; + this.dataMigrator = dataMigrator; + this.nodeAdVersions = new ConcurrentHashMap<>(); + this.circles = new TreeMap<>(); + this.circlesForRealtimeAD = new TreeMap<>(); + this.hashRingInited = new AtomicBoolean(false); + this.nodeChangeEvents = new ConcurrentLinkedQueue<>(); + this.modelManager = modelManager; + } + + public boolean isHashRingInited() { + return hashRingInited.get(); + } + + /** + * Build AD version based circles with discovery node delta change. Listen to clusterManager event in + * {@link ADClusterEventListener#clusterChanged(ClusterChangedEvent)}. + * Will remove the removed nodes from cache and send request to newly added nodes to get their + * plugin information; then add new nodes to AD version hash ring. + * + * @param delta discovery node delta change + * @param listener action listener + */ + public void buildCircles(DiscoveryNodes.Delta delta, ActionListener listener) { + if (!buildHashRingSemaphore.tryAcquire()) { + LOG.info("AD version hash ring change is in progress. Can't build hash ring for node delta event."); + listener.onResponse(false); + return; + } + Set removedNodeIds = delta.removed() + ? delta.removedNodes().stream().map(DiscoveryNode::getId).collect(Collectors.toSet()) + : null; + Set addedNodeIds = delta.added() ? delta.addedNodes().stream().map(DiscoveryNode::getId).collect(Collectors.toSet()) : null; + buildCircles(removedNodeIds, addedNodeIds, listener); + } + + /** + * Build AD version based circles by comparing with all eligible data nodes. + * 1. Remove nodes which are not eligible now; + * 2. Add nodes which are not in AD version circles. + * @param actionListener action listener + */ + public void buildCircles(ActionListener actionListener) { + if (!buildHashRingSemaphore.tryAcquire()) { + LOG.info("AD version hash ring change is in progress. 
Can't rebuild hash ring."); + actionListener.onResponse(false); + return; + } + DiscoveryNode[] allNodes = nodeFilter.getAllNodes(); + Set nodeIds = new HashSet<>(); + for (DiscoveryNode node : allNodes) { + nodeIds.add(node.getId()); + } + Set currentNodeIds = nodeAdVersions.keySet(); + Set removedNodeIds = Sets.difference(currentNodeIds, nodeIds); + Set addedNodeIds = Sets.difference(nodeIds, currentNodeIds); + buildCircles(removedNodeIds, addedNodeIds, actionListener); + } + + public void buildCirclesForRealtimeAD() { + if (nodeChangeEvents.isEmpty()) { + return; + } + buildCircles( + ActionListener + .wrap( + r -> { LOG.debug("build circles on AD versions successfully"); }, + e -> { LOG.error("Failed to build circles on AD versions", e); } + ) + ); + } + + /** + * Build AD version hash ring. + * 1. Delete removed nodes from AD version hash ring. + * 2. Add new nodes to AD version hash ring + * + * If fail to acquire semaphore to update AD version hash ring, will return false to + * action listener; otherwise will return true. The "true" response just mean we got + * semaphore and finished rebuilding hash ring, but the hash ring may stay the same. + * Hash ring changed or not depends on if "removedNodeIds" or "addedNodeIds" is empty. + * + * We use different way to build hash ring for realtime job and historical analysis + * 1. For historical analysis,if node removed, we remove it immediately from adVersionCircles + * to avoid new AD task routes to it. If new node added, we add it immediately to adVersionCircles + * to make load more balanced and speed up AD task running. + * 2. For realtime job, we don't record which node running detector's model partition. We just + * use hash ring to get owning node. If we rebuild hash ring frequently, realtime job may get + * different owning node and need to restore model on new owning node. If that happens a lot, + * it may bring heavy load to cluster. So we prefer to wait for some time before next hash ring + * rebuild, we call it cooldown period. The cons is we may have stale hash ring during cooldown + * period. Some node may already been removed from hash ring, then realtime job won't know this + * and still send RCF request to it. If new node added during cooldown period, realtime job won't + * choose it as model partition owning node, thus we may have skewed load on data nodes. + * + * [Important!]: When you call this function, make sure you TRY ACQUIRE adVersionCircleInProgress first. 
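The TRY ACQUIRE warning above is the heart of the concurrency story: a single-permit semaphore is taken non-blockingly before any rebuild and released on every exit path, so at most one thread mutates the rings at a time. A stripped-down, synchronous sketch of that guard is below; the plugin itself releases the permit inside its async callbacks, and the names here are illustrative.

import java.util.concurrent.Semaphore;
import java.util.function.Consumer;

final class SingleRebuildGuardSketch {
    private final Semaphore rebuildPermit = new Semaphore(1);

    void rebuild(Runnable doRebuild, Consumer<Boolean> listener) {
        if (!rebuildPermit.tryAcquire()) {
            // another rebuild is in flight; report "not rebuilt" instead of blocking
            listener.accept(false);
            return;
        }
        try {
            doRebuild.run();
            listener.accept(true);
        } finally {
            rebuildPermit.release();
        }
    }
}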
+ * Check {@link HashRing#buildCircles(ActionListener)} and + * {@link HashRing#buildCircles(DiscoveryNodes.Delta, ActionListener)} + * + * @param removedNodeIds removed node ids + * @param addedNodeIds added node ids + * @param actionListener action listener + */ + private void buildCircles(Set removedNodeIds, Set addedNodeIds, ActionListener actionListener) { + if (buildHashRingSemaphore.availablePermits() != 0) { + throw new TimeSeriesException("Must get update hash ring semaphore before building AD hash ring"); + } + try { + DiscoveryNode localNode = clusterService.localNode(); + if (removedNodeIds != null && removedNodeIds.size() > 0) { + LOG.info("Node removed: {}", Arrays.toString(removedNodeIds.toArray(new String[0]))); + for (String nodeId : removedNodeIds) { + ADNodeInfo nodeInfo = nodeAdVersions.remove(nodeId); + if (nodeInfo != null && nodeInfo.isEligibleDataNode()) { + removeNodeFromCircles(nodeId, nodeInfo.getAdVersion()); + LOG.info("Remove data node from AD version hash ring: {}", nodeId); + } + } + } + Set allAddedNodes = new HashSet<>(); + + if (addedNodeIds != null) { + allAddedNodes.addAll(addedNodeIds); + } + if (!nodeAdVersions.containsKey(localNode.getId())) { + allAddedNodes.add(localNode.getId()); + } + if (allAddedNodes.size() == 0) { + actionListener.onResponse(true); + // rebuild AD version hash ring with cooldown. + rebuildCirclesForRealtimeAD(); + buildHashRingSemaphore.release(); + return; + } + + LOG.info("Node added: {}", Arrays.toString(allAddedNodes.toArray(new String[0]))); + NodesInfoRequest nodesInfoRequest = new NodesInfoRequest(); + nodesInfoRequest.nodesIds(allAddedNodes.toArray(new String[0])); + nodesInfoRequest.clear().addMetric(NodesInfoRequest.Metric.PLUGINS.metricName()); + + AdminClient admin = client.admin(); + ClusterAdminClient cluster = admin.cluster(); + cluster.nodesInfo(nodesInfoRequest, ActionListener.wrap(r -> { + Map nodesMap = r.getNodesMap(); + if (nodesMap != null && nodesMap.size() > 0) { + for (Map.Entry entry : nodesMap.entrySet()) { + NodeInfo nodeInfo = entry.getValue(); + PluginsAndModules plugins = nodeInfo.getInfo(PluginsAndModules.class); + DiscoveryNode curNode = nodeInfo.getNode(); + if (plugins == null) { + continue; + } + TreeMap circle = null; + for (PluginInfo pluginInfo : plugins.getPluginInfos()) { + if (CommonName.TIME_SERIES_PLUGIN_NAME.equals(pluginInfo.getName()) + || CommonName.TIME_SERIES_PLUGIN_NAME_FOR_TEST.equals(pluginInfo.getName())) { + Version version = ADVersionUtil.fromString(pluginInfo.getVersion()); + boolean eligibleNode = nodeFilter.isEligibleNode(curNode); + if (eligibleNode) { + circle = circles.computeIfAbsent(version, key -> new TreeMap<>()); + LOG.info("Add data node to AD version hash ring: {}", curNode.getId()); + } + nodeAdVersions.put(curNode.getId(), new ADNodeInfo(version, eligibleNode)); + break; + } + } + if (circle != null) { + for (int i = 0; i < VIRTUAL_NODE_COUNT; i++) { + circle.put(Murmur3HashFunction.hash(curNode.getId() + i), curNode); + } + } + } + } + LOG.info("All nodes with known AD version: {}", nodeAdVersions); + + // rebuild AD version hash ring with cooldown after all new node added. + rebuildCirclesForRealtimeAD(); + + if (!dataMigrator.isMigrated() && circles.size() > 0) { + // Find owning node with highest AD version to make sure the data migration logic be compatible to + // latest AD version when upgrade. 
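The owning-node decision for data migration walks the highest-version ring: the outer map is sorted by plugin version, so lastEntry() picks the newest ring, and the model id's hash is then looked up clockwise with a wrap-around to the first entry. A simplified two-level lookup is sketched below; it keys versions as strings for brevity, whereas the plugin compares real Version objects, and the names are illustrative.

import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;

final class HighestVersionLookupSketch {
    // outer key: plugin version; inner ring: hash position -> node id
    static Optional<String> owningNodeOnNewestRing(TreeMap<String, TreeMap<Integer, String>> circles, String modelId) {
        Map.Entry<String, TreeMap<Integer, String>> newest = circles.lastEntry();
        if (newest == null || newest.getValue().isEmpty()) {
            return Optional.empty();
        }
        TreeMap<Integer, String> ring = newest.getValue();
        Map.Entry<Integer, String> slot = ring.higherEntry(modelId.hashCode());
        if (slot == null) {
            slot = ring.firstEntry(); // wrap around the ring
        }
        return Optional.of(slot.getValue());
    }
}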
+ Optional owningNode = getOwningNodeWithHighestAdVersion(DEFAULT_HASH_RING_MODEL_ID); + String localNodeId = localNode.getId(); + if (owningNode.isPresent() && localNodeId.equals(owningNode.get().getId())) { + dataMigrator.migrateData(); + } else { + dataMigrator.skipMigration(); + } + } + buildHashRingSemaphore.release(); + hashRingInited.set(true); + actionListener.onResponse(true); + }, e -> { + buildHashRingSemaphore.release(); + actionListener.onFailure(e); + LOG.error("Fail to get node info to build AD version hash ring", e); + })); + } catch (Exception e) { + LOG.error("Failed to build AD version circles", e); + buildHashRingSemaphore.release(); + actionListener.onFailure(e); + } + } + + private void removeNodeFromCircles(String nodeId, Version adVersion) { + if (adVersion != null) { + TreeMap circle = this.circles.get(adVersion); + List deleted = new ArrayList<>(); + for (Map.Entry entry : circle.entrySet()) { + if (entry.getValue().getId().equals(nodeId)) { + deleted.add(entry.getKey()); + } + } + if (deleted.size() == circle.size()) { + circles.remove(adVersion); + } else { + for (Integer key : deleted) { + circle.remove(key); + } + } + } + } + + private void rebuildCirclesForRealtimeAD() { + // Check if it's eligible to rebuild hash ring with cooldown + if (eligibleToRebuildCirclesForRealtimeAD()) { + LOG.info("Rebuild AD hash ring for realtime AD with cooldown, nodeChangeEvents size {}", nodeChangeEvents.size()); + int size = nodeChangeEvents.size(); + TreeMap> newCircles = new TreeMap<>(); + for (Map.Entry> entry : circles.entrySet()) { + newCircles.put(entry.getKey(), new TreeMap<>(entry.getValue())); + } + circlesForRealtimeAD = newCircles; + lastUpdateForRealtimeAD = clock.millis(); + LOG.info("Build AD version hash ring successfully"); + String localNodeId = clusterService.localNode().getId(); + Set modelIds = modelManager.getAllModelIds(); + for (String modelId : modelIds) { + Optional node = getOwningNodeWithSameLocalAdVersionForRealtimeAD(modelId); + if (node.isPresent() && !node.get().getId().equals(localNodeId)) { + LOG.info(REMOVE_MODEL_MSG + " {}", modelId); + modelManager + .stopModel( + // stopModel will clear model cache + SingleStreamModelIdMapper.getDetectorIdForModelId(modelId), + modelId, + ActionListener + .wrap( + r -> LOG.info("Stopped model [{}] with response [{}]", modelId, r), + e -> LOG.error("Fail to stop model " + modelId, e) + ) + ); + } + } + // It's possible that multiple threads add new event to nodeChangeEvents, + // but this is the only place to consume/poll the event and there is only + // one thread poll it as we are using adVersionCircleInProgress semaphore(1) + // to control only 1 thread build hash ring. + while (size-- > 0) { + Boolean poll = nodeChangeEvents.poll(); + if (poll == null) { + break; + } + } + } + } + + /** + * Check if it's eligible to rebuilt hash ring now. + * It's eligible if: + * 1. There is node change event not consumed, and + * 2. Have passed cool down period from last hash ring update time. + * + * Check {@link org.opensearch.ad.settings.AnomalyDetectorSettings#COOLDOWN_MINUTES} about + * cool down settings. + * + * Why we need to wait for some cooldown period before rebuilding hash ring? + * This is for realtime detection. In realtime detection, we rely on hash ring to get + * owning node for RCF model partitions. It's stateless, that means we don't record + * which node is running RCF partition for the detector. That requires a stable hash + * ring. 
If hash ring changes, it's possible that the next job run will use a different + * node to run RCF partition. Then we need to restore model on the new node and clean up + * old model partitions on old node. That model migration between nodes may bring heavy + * load to cluster. + * + * @return true if it's eligible to rebuild hash ring + */ + protected boolean eligibleToRebuildCirclesForRealtimeAD() { + // Check if there is any node change event + if (nodeChangeEvents.isEmpty() && !circlesForRealtimeAD.isEmpty()) { + return false; + } + + // Check cooldown period + if (clock.millis() - lastUpdateForRealtimeAD <= coolDownPeriodForRealtimeAD.getMillis()) { + LOG.debug(COOLDOWN_MSG); + return false; + } + return true; + } + + /** + * Get owning node with highest AD version circle. + * @param modelId model id + * @return owning node + */ + public Optional getOwningNodeWithHighestAdVersion(String modelId) { + int modelHash = Murmur3HashFunction.hash(modelId); + Map.Entry> versionTreeMapEntry = circles.lastEntry(); + if (versionTreeMapEntry == null) { + return Optional.empty(); + } + TreeMap adVersionCircle = versionTreeMapEntry.getValue(); + Map.Entry entry = adVersionCircle.higherEntry(modelHash); + return Optional.ofNullable(Optional.ofNullable(entry).orElse(adVersionCircle.firstEntry())).map(x -> x.getValue()); + } + + /** + * Get owning node with same AD version of local node. + * @param modelId model id + * @param function consumer function + * @param listener action listener + * @param listener response type + */ + public void buildAndGetOwningNodeWithSameLocalAdVersion( + String modelId, + Consumer> function, + ActionListener listener + ) { + buildCircles(ActionListener.wrap(r -> { + DiscoveryNode localNode = clusterService.localNode(); + Version adVersion = nodeAdVersions.containsKey(localNode.getId()) ? getAdVersion(localNode.getId()) : Version.CURRENT; + Optional owningNode = getOwningNodeWithSameAdVersionDirectly(modelId, adVersion, false); + function.accept(owningNode); + }, e -> listener.onFailure(e))); + } + + public Optional getOwningNodeWithSameLocalAdVersionForRealtimeAD(String modelId) { + try { + DiscoveryNode localNode = clusterService.localNode(); + Version adVersion = nodeAdVersions.containsKey(localNode.getId()) ? getAdVersion(localNode.getId()) : Version.CURRENT; + Optional owningNode = getOwningNodeWithSameAdVersionDirectly(modelId, adVersion, true); + // rebuild hash ring + buildCirclesForRealtimeAD(); + return owningNode; + } catch (Exception e) { + LOG.error("Failed to get owning node with same local AD version", e); + return Optional.empty(); + } + } + + private Optional getOwningNodeWithSameAdVersionDirectly(String modelId, Version adVersion, boolean forRealtime) { + int modelHash = Murmur3HashFunction.hash(modelId); + TreeMap adVersionCircle = forRealtime ? circlesForRealtimeAD.get(adVersion) : circles.get(adVersion); + if (adVersionCircle != null) { + Map.Entry entry = adVersionCircle.higherEntry(modelHash); + return Optional.ofNullable(Optional.ofNullable(entry).orElse(adVersionCircle.firstEntry())).map(x -> x.getValue()); + } + return Optional.empty(); + } + + public void getNodesWithSameLocalAdVersion(Consumer function, ActionListener listener) { + buildCircles(ActionListener.wrap(updated -> { + DiscoveryNode localNode = clusterService.localNode(); + Version adVersion = nodeAdVersions.containsKey(localNode.getId()) ? 
getAdVersion(localNode.getId()) : Version.CURRENT; + Set nodes = getNodesWithSameAdVersion(adVersion, false); + if (!nodeAdVersions.containsKey(localNode.getId())) { + nodes.add(localNode); + } + // Make sure listener return in function + function.accept(nodes.toArray(new DiscoveryNode[0])); + }, e -> listener.onFailure(e))); + } + + public DiscoveryNode[] getNodesWithSameLocalAdVersion() { + DiscoveryNode localNode = clusterService.localNode(); + Version adVersion = nodeAdVersions.containsKey(localNode.getId()) ? getAdVersion(localNode.getId()) : Version.CURRENT; + Set nodes = getNodesWithSameAdVersion(adVersion, false); + // rebuild hash ring + buildCirclesForRealtimeAD(); + return nodes.toArray(new DiscoveryNode[0]); + } + + protected Set getNodesWithSameAdVersion(Version adVersion, boolean forRealtime) { + TreeMap circle = forRealtime ? circlesForRealtimeAD.get(adVersion) : circles.get(adVersion); + Set nodeIds = new HashSet<>(); + Set nodes = new HashSet<>(); + if (circle == null) { + return nodes; + } + circle.entrySet().stream().forEach(e -> { + DiscoveryNode discoveryNode = e.getValue(); + if (!nodeIds.contains(discoveryNode.getId())) { + nodeIds.add(discoveryNode.getId()); + nodes.add(discoveryNode); + } + }); + return nodes; + } + + /** + * Get AD version. + * @param nodeId node id + * @return AD version + */ + public Version getAdVersion(String nodeId) { + ADNodeInfo adNodeInfo = nodeAdVersions.get(nodeId); + return adNodeInfo == null ? null : adNodeInfo.getAdVersion(); + } + + /** + * Get node by transport address. + * If transport address is null, return local node; otherwise, filter current eligible data nodes + * with IP address. If no node found, will return Optional.empty() + * + * @param address transport address + * @return discovery node + */ + public Optional getNodeByAddress(TransportAddress address) { + if (address == null) { + // If remote address of transport request is null, that means remote node is local node. + return Optional.of(clusterService.localNode()); + } + String ipAddress = getIpAddress(address); + DiscoveryNode[] allNodes = nodeFilter.getAllNodes(); + + // Can't handle this edge case for BWC of AD1.0: mixed cluster with AD1.0 and Version after 1.1. + // Start multiple OpenSearch processes on same IP, some run AD 1.0, some run new AD + // on or after 1.1. As we ignore port number in transport address, just look for node + // with IP like "127.0.0.1", so it's possible that we get wrong node as all nodes have + // same IP. + for (DiscoveryNode node : allNodes) { + if (getIpAddress(node.getAddress()).equals(ipAddress)) { + return Optional.ofNullable(node); + } + } + return Optional.empty(); + } + + /** + * Get IP address from transport address. + * TransportAddress.toString() example: 100.200.100.200:12345 + * @param address transport address + * @return IP address + */ + private String getIpAddress(TransportAddress address) { + // Ignore port number as it may change, just use ip to look for node + return address.toString().split(":")[0]; + } + + /** + * Get all eligible data nodes whose AD versions are known in AD version based hash ring. 
+ * @param function consumer function + * @param listener action listener + * @param action listener response type + */ + public void getAllEligibleDataNodesWithKnownAdVersion(Consumer function, ActionListener listener) { + buildCircles(ActionListener.wrap(r -> { + DiscoveryNode[] eligibleDataNodes = nodeFilter.getEligibleDataNodes(); + List allNodes = new ArrayList<>(); + for (DiscoveryNode node : eligibleDataNodes) { + if (nodeAdVersions.containsKey(node.getId())) { + allNodes.add(node); + } + } + // Make sure listener return in function + function.accept(allNodes.toArray(new DiscoveryNode[0])); + }, e -> listener.onFailure(e))); + } + + /** + * Put node change events in node change event queue. Will poll event from this queue when rebuild hash ring + * for realtime task. + * We track all node change events in case some race condition happen and we miss adding some node to hash + * ring. + */ + public void addNodeChangeEvent() { + this.nodeChangeEvents.add(true); + } +} diff --git a/src/main/java/org/opensearch/ad/cluster/HourlyCron.java-e b/src/main/java/org/opensearch/ad/cluster/HourlyCron.java-e new file mode 100644 index 000000000..a81156bb0 --- /dev/null +++ b/src/main/java/org/opensearch/ad/cluster/HourlyCron.java-e @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.cluster; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.FailedNodeException; +import org.opensearch.ad.transport.CronAction; +import org.opensearch.ad.transport.CronRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +public class HourlyCron implements Runnable { + private static final Logger LOG = LogManager.getLogger(HourlyCron.class); + static final String SUCCEEDS_LOG_MSG = "Hourly maintenance succeeds"; + static final String NODE_EXCEPTION_LOG_MSG = "Hourly maintenance of node has exception"; + static final String EXCEPTION_LOG_MSG = "Hourly maintenance has exception."; + private DiscoveryNodeFilterer nodeFilter; + private Client client; + + public HourlyCron(Client client, DiscoveryNodeFilterer nodeFilter) { + this.nodeFilter = nodeFilter; + this.client = client; + } + + @Override + public void run() { + DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); + + // we also add the cancel query function here based on query text from the negative cache. + + // Length of detector id is 20. Here we create a random string as request id to get hash with + // HashRing, then we can control some maintaining task to just run on one data node. Read + // ADTaskManager#maintainRunningHistoricalTasks for more details. 
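The comment above alludes to a simple trick used by the task manager: to make a cluster-wide maintenance step run on exactly one node, hash a fresh random id (the same length as a detector id) onto the ring and let only the resulting owner act. A rough illustration of the idea, not the plugin's implementation:

import java.util.TreeMap;
import java.util.UUID;

final class SingleNodeMaintenanceSketch {
    // Returns the single node that should perform this maintenance round, or null if the ring is empty.
    static String pickMaintenanceNode(TreeMap<Integer, String> ring) {
        if (ring.isEmpty()) {
            return null;
        }
        String requestId = UUID.randomUUID().toString().substring(0, 20); // detector ids are 20 characters
        Integer slot = ring.higherKey(requestId.hashCode());
        return slot == null ? ring.firstEntry().getValue() : ring.get(slot);
    }
}

Each node can run this check locally and proceed only when the returned owner matches its own node id.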
+ CronRequest modelDeleteRequest = new CronRequest(dataNodes); + client.execute(CronAction.INSTANCE, modelDeleteRequest, ActionListener.wrap(response -> { + if (response.hasFailures()) { + for (FailedNodeException failedNodeException : response.failures()) { + LOG.warn(NODE_EXCEPTION_LOG_MSG, failedNodeException); + } + } else { + LOG.info(SUCCEEDS_LOG_MSG); + } + }, exception -> { LOG.error(EXCEPTION_LOG_MSG, exception); })); + } +} diff --git a/src/main/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanup.java-e b/src/main/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanup.java-e new file mode 100644 index 000000000..325361aec --- /dev/null +++ b/src/main/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanup.java-e @@ -0,0 +1,126 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.cluster.diskcleanup; + +import java.util.Arrays; +import java.util.Objects; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.indices.stats.CommonStats; +import org.opensearch.action.admin.indices.stats.IndicesStatsRequest; +import org.opensearch.action.admin.indices.stats.IndicesStatsResponse; +import org.opensearch.action.admin.indices.stats.ShardStats; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.index.store.StoreStats; + +/** + * Clean up the old docs for indices. + */ +public class IndexCleanup { + private static final Logger LOG = LogManager.getLogger(IndexCleanup.class); + + private final Client client; + private final ClientUtil clientUtil; + private final ClusterService clusterService; + + public IndexCleanup(Client client, ClientUtil clientUtil, ClusterService clusterService) { + this.client = client; + this.clientUtil = clientUtil; + this.clusterService = clusterService; + } + + /** + * delete docs when shard size is bigger than max limitation. 
+ * @param indexName index name + * @param maxShardSize max shard size + * @param queryForDeleteByQueryRequest query request + * @param listener action listener + */ + public void deleteDocsBasedOnShardSize( + String indexName, + long maxShardSize, + QueryBuilder queryForDeleteByQueryRequest, + ActionListener listener + ) { + + if (!clusterService.state().getRoutingTable().hasIndex(indexName)) { + LOG.debug("skip as the index:{} doesn't exist", indexName); + return; + } + + ActionListener indicesStatsResponseListener = ActionListener.wrap(indicesStatsResponse -> { + // Check if any shard size is bigger than maxShardSize + boolean cleanupNeeded = Arrays + .stream(indicesStatsResponse.getShards()) + .map(ShardStats::getStats) + .filter(Objects::nonNull) + .map(CommonStats::getStore) + .filter(Objects::nonNull) + .map(StoreStats::getSizeInBytes) + .anyMatch(size -> size > maxShardSize); + + if (cleanupNeeded) { + deleteDocsByQuery( + indexName, + queryForDeleteByQueryRequest, + ActionListener.wrap(r -> listener.onResponse(true), listener::onFailure) + ); + } else { + listener.onResponse(false); + } + }, listener::onFailure); + + getCheckpointShardStoreStats(indexName, indicesStatsResponseListener); + } + + private void getCheckpointShardStoreStats(String indexName, ActionListener listener) { + IndicesStatsRequest indicesStatsRequest = new IndicesStatsRequest(); + indicesStatsRequest.store(); + indicesStatsRequest.indices(indexName); + client.admin().indices().stats(indicesStatsRequest, listener); + } + + /** + * Delete docs based on query request + * @param indexName index name + * @param queryForDeleteByQueryRequest query request + * @param listener action listener + */ + public void deleteDocsByQuery(String indexName, QueryBuilder queryForDeleteByQueryRequest, ActionListener listener) { + DeleteByQueryRequest deleteRequest = new DeleteByQueryRequest(indexName) + .setQuery(queryForDeleteByQueryRequest) + .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) + .setRefresh(true); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + clientUtil.execute(DeleteByQueryAction.INSTANCE, deleteRequest, ActionListener.wrap(response -> { + long deleted = response.getDeleted(); + if (deleted > 0) { + // if 0 docs get deleted, it means our query cannot find any matching doc + // or the index does not exist at all + LOG.info("{} docs are deleted for index:{}", deleted, indexName); + } + listener.onResponse(response.getDeleted()); + }, listener::onFailure)); + } + + } +} diff --git a/src/main/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetention.java-e b/src/main/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetention.java-e new file mode 100644 index 000000000..c966c7b78 --- /dev/null +++ b/src/main/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetention.java-e @@ -0,0 +1,120 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.cluster.diskcleanup; + +import java.time.Clock; +import java.time.Duration; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.timeseries.constant.CommonName; + +/** + * Model checkpoints cleanup of multi-entity detectors. + *
Problem: + * In multi-entity detectors we can have thousands, even millions, of entities, whose model checkpoints consume + * a lot of disk space. To protect our disk usage, the checkpoint index size is limited by a configured threshold. + * Once its size exceeds the threshold, the model checkpoint cleanup process is activated. + * + * Solution: + * Before multi-entity detectors, a daily cron job cleaned up inactive checkpoints older than a configurable number of days. + * We keep this logic and add a new cleanup path based on shard size.
+ */ +public class ModelCheckpointIndexRetention implements Runnable { + private static final Logger LOG = LogManager.getLogger(ModelCheckpointIndexRetention.class); + + // The recommended max shard size is 50G, we don't wanna our index exceeds this number + private static final long MAX_SHARD_SIZE_IN_BYTE = 50 * 1024 * 1024 * 1024L; + // We can't clean up all of the checkpoints. At least keep models for 1 day + private static final Duration MINIMUM_CHECKPOINT_TTL = Duration.ofDays(1); + static final String CHECKPOINT_NOT_EXIST_MSG = "Checkpoint index does not exist."; + + private final Duration defaultCheckpointTtl; + private final Clock clock; + private final IndexCleanup indexCleanup; + + public ModelCheckpointIndexRetention(Duration defaultCheckpointTtl, Clock clock, IndexCleanup indexCleanup) { + this.defaultCheckpointTtl = defaultCheckpointTtl; + this.clock = clock; + this.indexCleanup = indexCleanup; + } + + @Override + public void run() { + indexCleanup + .deleteDocsByQuery( + ADCommonName.CHECKPOINT_INDEX_NAME, + QueryBuilders + .boolQuery() + .filter( + QueryBuilders + .rangeQuery(CommonName.TIMESTAMP) + .lte(clock.millis() - defaultCheckpointTtl.toMillis()) + .format(ADCommonName.EPOCH_MILLIS_FORMAT) + ), + ActionListener + .wrap( + response -> { cleanupBasedOnShardSize(defaultCheckpointTtl.minusDays(1)); }, + // The docs will be deleted in next scheduled windows. No need for retrying. + exception -> LOG.error("delete docs by query fails for checkpoint index", exception) + ) + ); + + } + + private void cleanupBasedOnShardSize(Duration cleanUpTtl) { + indexCleanup + .deleteDocsBasedOnShardSize( + ADCommonName.CHECKPOINT_INDEX_NAME, + MAX_SHARD_SIZE_IN_BYTE, + QueryBuilders + .boolQuery() + .filter( + QueryBuilders + .rangeQuery(CommonName.TIMESTAMP) + .lte(clock.millis() - cleanUpTtl.toMillis()) + .format(ADCommonName.EPOCH_MILLIS_FORMAT) + ), + ActionListener.wrap(cleanupNeeded -> { + if (cleanupNeeded) { + if (cleanUpTtl.equals(MINIMUM_CHECKPOINT_TTL)) { + return; + } + + Duration nextCleanupTtl = cleanUpTtl.minusDays(1); + if (nextCleanupTtl.compareTo(MINIMUM_CHECKPOINT_TTL) < 0) { + nextCleanupTtl = MINIMUM_CHECKPOINT_TTL; + } + cleanupBasedOnShardSize(nextCleanupTtl); + } else { + LOG.debug("clean up not needed anymore for checkpoint index"); + } + }, + // The docs will be deleted in next scheduled windows. No need for retrying. + exception -> { + if (exception instanceof IndexNotFoundException) { + // the method will be called hourly + // don't log stack trace as most of OpenSearch domains have no AD installed + LOG.debug(CHECKPOINT_NOT_EXIST_MSG); + } else { + LOG.error("checkpoint index retention based on shard size fails", exception); + } + } + ) + ); + } +} diff --git a/src/main/java/org/opensearch/ad/constant/ADCommonMessages.java-e b/src/main/java/org/opensearch/ad/constant/ADCommonMessages.java-e new file mode 100644 index 000000000..e20dc8fd1 --- /dev/null +++ b/src/main/java/org/opensearch/ad/constant/ADCommonMessages.java-e @@ -0,0 +1,85 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
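Taken together, the retention logic above first deletes checkpoints older than the default TTL and then, while any shard of the checkpoint index still exceeds the 50 GB ceiling, keeps shrinking the cutoff a day at a time, never dropping below the one-day floor. The recursion can be read as the following small loop, written synchronously with a pluggable "delete and re-check size" step; it is an illustration, not the plugin's code.

import java.time.Duration;
import java.util.function.Predicate;

final class RetentionTtlSketch {
    private static final Duration MINIMUM_TTL = Duration.ofDays(1);

    // stillTooLarge deletes docs older than the given TTL and reports whether a shard is still over the limit.
    static Duration shrinkUntilFits(Duration startTtl, Predicate<Duration> stillTooLarge) {
        Duration ttl = startTtl;
        while (stillTooLarge.test(ttl)) {
            if (ttl.equals(MINIMUM_TTL)) {
                break; // never delete checkpoints newer than one day
            }
            ttl = ttl.minusDays(1);
            if (ttl.compareTo(MINIMUM_TTL) < 0) {
                ttl = MINIMUM_TTL;
            }
        }
        return ttl;
    }
}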
+ */ + +package org.opensearch.ad.constant; + +import static org.opensearch.ad.constant.ADCommonName.CUSTOM_RESULT_INDEX_PREFIX; + +public class ADCommonMessages { + public static final String AD_ID_MISSING_MSG = "AD ID is missing"; + public static final String MODEL_ID_MISSING_MSG = "Model ID is missing"; + public static final String HASH_ERR_MSG = "Cannot find an RCF node. Hashing does not work."; + public static final String NO_CHECKPOINT_ERR_MSG = "No checkpoints found for model id "; + public static final String FEATURE_NOT_AVAILABLE_ERR_MSG = "No Feature in current detection window."; + public static final String DISABLED_ERR_MSG = + "AD functionality is disabled. To enable update plugins.anomaly_detection.enabled to true"; + public static String FAIL_TO_PARSE_DETECTOR_MSG = "Fail to parse detector with id: "; + public static String FAIL_TO_GET_PROFILE_MSG = "Fail to get profile for detector "; + public static String FAIL_TO_GET_TOTAL_ENTITIES = "Failed to get total entities for detector "; + public static String FAIL_TO_GET_USER_INFO = "Unable to get user information from detector "; + public static String NO_PERMISSION_TO_ACCESS_DETECTOR = "User does not have permissions to access detector: "; + public static String CATEGORICAL_FIELD_NUMBER_SURPASSED = "We don't support categorical fields more than "; + public static String EMPTY_PROFILES_COLLECT = "profiles to collect are missing or invalid"; + public static String FAIL_FETCH_ERR_MSG = "Fail to fetch profile for "; + public static String DETECTOR_IS_RUNNING = "Detector is already running"; + public static String DETECTOR_MISSING = "Detector is missing"; + public static String AD_TASK_ACTION_MISSING = "AD task action is missing"; + public static final String INDEX_NOT_FOUND = "index does not exist"; + public static final String NOT_EXISTENT_VALIDATION_TYPE = "The given validation type doesn't exist"; + public static final String UNSUPPORTED_PROFILE_TYPE = "Unsupported profile types"; + + public static final String REQUEST_THROTTLED_MSG = "Request throttled. 
Please try again later."; + public static String NULL_DETECTION_INTERVAL = "Detection interval should be set"; + public static String INVALID_SHINGLE_SIZE = "Shingle size must be a positive integer"; + public static String INVALID_DETECTION_INTERVAL = "Detection interval must be a positive integer"; + public static String EXCEED_HISTORICAL_ANALYSIS_LIMIT = "Exceed max historical analysis limit per node"; + public static String NO_ELIGIBLE_NODE_TO_RUN_DETECTOR = "No eligible node to run detector "; + public static String EMPTY_STALE_RUNNING_ENTITIES = "Empty stale running entities"; + public static String CAN_NOT_FIND_LATEST_TASK = "can't find latest task"; + public static String NO_ENTITY_FOUND = "No entity found"; + public static String HISTORICAL_ANALYSIS_CANCELLED = "Historical analysis cancelled by user"; + public static String HC_DETECTOR_TASK_IS_UPDATING = "HC detector task is updating"; + public static String INVALID_TIME_CONFIGURATION_UNITS = "Time unit %s is not supported"; + public static String FAIL_TO_GET_DETECTOR = "Fail to get detector"; + public static String FAIL_TO_GET_DETECTOR_INFO = "Fail to get detector info"; + public static String FAIL_TO_CREATE_DETECTOR = "Fail to create detector"; + public static String FAIL_TO_UPDATE_DETECTOR = "Fail to update detector"; + public static String FAIL_TO_PREVIEW_DETECTOR = "Fail to preview detector"; + public static String FAIL_TO_START_DETECTOR = "Fail to start detector"; + public static String FAIL_TO_STOP_DETECTOR = "Fail to stop detector"; + public static String FAIL_TO_DELETE_DETECTOR = "Fail to delete detector"; + public static String FAIL_TO_DELETE_AD_RESULT = "Fail to delete anomaly result"; + public static String FAIL_TO_GET_STATS = "Fail to get stats"; + public static String FAIL_TO_SEARCH = "Fail to search"; + + public static String WINDOW_DELAY_REC = + "Latest seen data point is at least %d minutes ago, consider changing window delay to at least %d minutes."; + public static String TIME_FIELD_NOT_ENOUGH_HISTORICAL_DATA = + "There isn't enough historical data found with current timefield selected."; + public static String DETECTOR_INTERVAL_REC = + "The selected detector interval might collect sparse data. Consider changing interval length to: "; + public static String RAW_DATA_TOO_SPARSE = + "Source index data is potentially too sparse for model training. Consider changing interval length or ingesting more data"; + public static String MODEL_VALIDATION_FAILED_UNEXPECTEDLY = "Model validation experienced issues completing."; + public static String FILTER_QUERY_TOO_SPARSE = "Data is too sparse after data filter is applied. Consider changing the data filter"; + public static String CATEGORY_FIELD_TOO_SPARSE = + "Data is most likely too sparse with the given category fields. Consider revising category field/s or ingesting more data "; + public static String CATEGORY_FIELD_NO_DATA = + "No entity was found with the given categorical fields. Consider revising category field/s or ingesting more data"; + public static String FEATURE_QUERY_TOO_SPARSE = + "Data is most likely too sparse when given feature queries are applied. 
Consider revising feature queries."; + public static String TIMEOUT_ON_INTERVAL_REC = "Timed out getting interval recommendation"; + + public static final String NO_MODEL_ERR_MSG = "No RCF models are available either because RCF" + + " models are not ready or all nodes are unresponsive or the system might have bugs."; + public static String INVALID_RESULT_INDEX_PREFIX = "Result index must start with " + CUSTOM_RESULT_INDEX_PREFIX; + +} diff --git a/src/main/java/org/opensearch/ad/constant/ADCommonName.java b/src/main/java/org/opensearch/ad/constant/ADCommonName.java index c0e204b58..3a97db889 100644 --- a/src/main/java/org/opensearch/ad/constant/ADCommonName.java +++ b/src/main/java/org/opensearch/ad/constant/ADCommonName.java @@ -34,9 +34,6 @@ public class ADCommonName { // Anomaly Detector name for X-Opaque-Id header // ====================================== public static final String ANOMALY_DETECTOR = "[Anomaly Detector]"; - public static final String AD_PLUGIN_NAME = "opensearch-anomaly-detection"; - public static final String AD_PLUGIN_NAME_FOR_TEST = "org.opensearch.ad.AnomalyDetectorPlugin"; - public static final String AD_PLUGIN_VERSION_FOR_TEST = "NA"; // ====================================== // Ultrawarm node attributes diff --git a/src/main/java/org/opensearch/ad/constant/ADCommonName.java-e b/src/main/java/org/opensearch/ad/constant/ADCommonName.java-e new file mode 100644 index 000000000..3a97db889 --- /dev/null +++ b/src/main/java/org/opensearch/ad/constant/ADCommonName.java-e @@ -0,0 +1,105 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.constant; + +import org.opensearch.timeseries.stats.StatNames; + +public class ADCommonName { + // ====================================== + // Index name + // ====================================== + // index name for anomaly checkpoint of each model. One model one document. + public static final String CHECKPOINT_INDEX_NAME = ".opendistro-anomaly-checkpoints"; + // index name for anomaly detection state. Will store AD task in this index as well. 
+ public static final String DETECTION_STATE_INDEX = ".opendistro-anomaly-detection-state"; + + // The alias of the index in which to write AD result history + public static final String ANOMALY_RESULT_INDEX_ALIAS = ".opendistro-anomaly-results"; + + // ====================================== + // Format name + // ====================================== + public static final String EPOCH_MILLIS_FORMAT = "epoch_millis"; + + // ====================================== + // Anomaly Detector name for X-Opaque-Id header + // ====================================== + public static final String ANOMALY_DETECTOR = "[Anomaly Detector]"; + + // ====================================== + // Ultrawarm node attributes + // ====================================== + + // hot node + public static String HOT_BOX_TYPE = "hot"; + + // warm node + public static String WARM_BOX_TYPE = "warm"; + + // box type + public static final String BOX_TYPE_KEY = "box_type"; + + // ====================================== + // Profile name + // ====================================== + public static final String STATE = "state"; + public static final String ERROR = "error"; + public static final String COORDINATING_NODE = "coordinating_node"; + public static final String SHINGLE_SIZE = "shingle_size"; + public static final String TOTAL_SIZE_IN_BYTES = "total_size_in_bytes"; + public static final String MODELS = "models"; + public static final String MODEL = "model"; + public static final String INIT_PROGRESS = "init_progress"; + public static final String CATEGORICAL_FIELD = "category_field"; + public static final String TOTAL_ENTITIES = "total_entities"; + public static final String ACTIVE_ENTITIES = "active_entities"; + public static final String ENTITY_INFO = "entity_info"; + public static final String TOTAL_UPDATES = "total_updates"; + public static final String MODEL_COUNT = StatNames.MODEL_COUNT.getName(); + // ====================================== + // Historical detectors + // ====================================== + public static final String AD_TASK = "ad_task"; + public static final String HISTORICAL_ANALYSIS = "historical_analysis"; + public static final String AD_TASK_REMOTE = "ad_task_remote"; + public static final String CANCEL_TASK = "cancel_task"; + + // ====================================== + // Used in stats API + // ====================================== + public static final String DETECTOR_ID_KEY = "detector_id"; + + // ====================================== + // Used in toXContent + // ====================================== + public static final String RCF_SCORE_JSON_KEY = "rCFScore"; + public static final String ID_JSON_KEY = "adID"; + public static final String FEATURE_JSON_KEY = "features"; + public static final String CONFIDENCE_JSON_KEY = "confidence"; + public static final String ANOMALY_GRADE_JSON_KEY = "anomalyGrade"; + public static final String QUEUE_JSON_KEY = "queue"; + // ====================================== + // Used for backward-compatibility in messaging + // ====================================== + public static final String EMPTY_FIELD = ""; + + // Validation + // ====================================== + // detector validation aspect + public static final String DETECTOR_ASPECT = "detector"; + // ====================================== + // Used for custom AD result index + // ====================================== + public static final String DUMMY_AD_RESULT_ID = "dummy_ad_result_id"; + public static final String DUMMY_DETECTOR_ID = "dummy_detector_id"; + public static final String 
CUSTOM_RESULT_INDEX_PREFIX = "opensearch-ad-plugin-result-"; +} diff --git a/src/main/java/org/opensearch/ad/constant/CommonValue.java-e b/src/main/java/org/opensearch/ad/constant/CommonValue.java-e new file mode 100644 index 000000000..f5d5b15eb --- /dev/null +++ b/src/main/java/org/opensearch/ad/constant/CommonValue.java-e @@ -0,0 +1,19 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.constant; + +public class CommonValue { + // unknown or no schema version + public static Integer NO_SCHEMA_VERSION = 0; + public static String INTERNAL_ACTION_PREFIX = "cluster:admin/opendistro/adinternal/"; + public static String EXTERNAL_ACTION_PREFIX = "cluster:admin/opendistro/ad/"; +} diff --git a/src/main/java/org/opensearch/ad/feature/AbstractRetriever.java-e b/src/main/java/org/opensearch/ad/feature/AbstractRetriever.java-e new file mode 100644 index 000000000..886dbcbc4 --- /dev/null +++ b/src/main/java/org/opensearch/ad/feature/AbstractRetriever.java-e @@ -0,0 +1,119 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.feature; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; + +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.aggregations.bucket.InternalSingleBucketAggregation; +import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; +import org.opensearch.search.aggregations.metrics.InternalTDigestPercentiles; +import org.opensearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue; +import org.opensearch.search.aggregations.metrics.Percentile; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.common.exception.EndRunException; + +public abstract class AbstractRetriever { + protected double parseAggregation(Aggregation aggregationToParse) { + Double result = null; + /* example InternalSingleBucketAggregation: filter aggregation like + "t_shirts": { + "filter": { + "bool": { + "should": [ + { + "term": { + "issueType": "foo" + } + } + ... 
+ ], + "minimum_should_match": "1", + "boost": 1 + } + }, + "aggs": { + "impactUniqueAccounts": { + "aggregation": { + "field": "account" + } + } + } + } + + would produce an InternalFilter (a subtype of InternalSingleBucketAggregation) with a sub-aggregation + InternalCardinality that is also a SingleValue + */ + + if (aggregationToParse instanceof InternalSingleBucketAggregation) { + InternalAggregations bucket = ((InternalSingleBucketAggregation) aggregationToParse).getAggregations(); + if (bucket != null) { + List aggrs = bucket.asList(); + if (aggrs.size() == 1) { + // we only accept a single value as feature + aggregationToParse = aggrs.get(0); + } + } + } + + final Aggregation aggregation = aggregationToParse; + if (aggregation instanceof SingleValue) { + result = ((SingleValue) aggregation).value(); + } else if (aggregation instanceof InternalTDigestPercentiles) { + Iterator percentile = ((InternalTDigestPercentiles) aggregation).iterator(); + if (percentile.hasNext()) { + result = percentile.next().getValue(); + } + } + return Optional + .ofNullable(result) + .orElseThrow(() -> new EndRunException("Failed to parse aggregation " + aggregation, true).countedInStats(false)); + } + + protected Optional parseBucket(MultiBucketsAggregation.Bucket bucket, List featureIds) { + return parseAggregations(Optional.ofNullable(bucket).map(b -> b.getAggregations()), featureIds); + } + + protected Optional parseAggregations(Optional aggregations, List featureIds) { + return aggregations + .map(aggs -> aggs.asMap()) + .map( + map -> featureIds + .stream() + .mapToDouble(id -> Optional.ofNullable(map.get(id)).map(this::parseAggregation).orElse(Double.NaN)) + .toArray() + ) + .filter(result -> Arrays.stream(result).noneMatch(d -> Double.isNaN(d) || Double.isInfinite(d))); + } + + protected void updateSourceAfterKey(Map afterKey, SearchSourceBuilder search) { + AggregationBuilder aggBuilder = search.aggregations().getAggregatorFactories().iterator().next(); + // update after-key with the new value + if (aggBuilder instanceof CompositeAggregationBuilder) { + CompositeAggregationBuilder comp = (CompositeAggregationBuilder) aggBuilder; + comp.aggregateAfter(afterKey); + } else { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Invalid client request; expected a composite builder but instead got {}", aggBuilder) + ); + } + } +} diff --git a/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java-e b/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java-e new file mode 100644 index 000000000..d41bdf76e --- /dev/null +++ b/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java-e @@ -0,0 +1,425 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.feature; + +import java.io.IOException; +import java.time.Clock; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.apache.commons.lang.builder.ToStringBuilder; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation.Bucket; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.util.ParseUtils; + +/** + * + * Use pagination to fetch entities. If there are more than pageSize entities, + * we will fetch them in the next page. We implement pagination with composite query. + * Results are decomposed into pages. Each page encapsulates aggregated values for + * each entity and is sent to model nodes according to the hash ring mapping from + * entity model Id to a data node. 
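As a quick orientation for the paging contract described above, here is a minimal, hypothetical caller sketch; the retriever construction, the sendToModelNode helper, and the error handling are illustrative assumptions and are not part of this PR:

    // Hypothetical driver for the PageIterator defined further below (simplified sketch).
    CompositeRetriever.PageIterator pages = retriever.iterator();
    if (pages.hasNext()) {
        pages.next(ActionListener.wrap(page -> {
            if (page != null && !page.isEmpty()) {
                // each entry maps a composite key (Entity) to its feature vector for this interval
                page.getResults().forEach((entity, features) -> sendToModelNode(entity, features));
            }
            // a real caller re-checks pages.hasNext() and asynchronously requests the next page
        }, exception -> { /* surface the failure to the calling transport action */ }));
    }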
+ * + */ +public class CompositeRetriever extends AbstractRetriever { + public static final String AGG_NAME_COMP = "comp_agg"; + private static final Logger LOG = LogManager.getLogger(CompositeRetriever.class); + + private final long dataStartEpoch; + private final long dataEndEpoch; + private final AnomalyDetector anomalyDetector; + private final NamedXContentRegistry xContent; + private final Client client; + private final SecurityClientUtil clientUtil; + private int totalResults; + // we can process at most maxEntities entities + private int maxEntities; + private final int pageSize; + private long expirationEpochMs; + private Clock clock; + private IndexNameExpressionResolver indexNameExpressionResolver; + private ClusterService clusterService; + + public CompositeRetriever( + long dataStartEpoch, + long dataEndEpoch, + AnomalyDetector anomalyDetector, + NamedXContentRegistry xContent, + Client client, + SecurityClientUtil clientUtil, + long expirationEpochMs, + Clock clock, + Settings settings, + int maxEntitiesPerInterval, + int pageSize, + IndexNameExpressionResolver indexNameExpressionResolver, + ClusterService clusterService + ) { + this.dataStartEpoch = dataStartEpoch; + this.dataEndEpoch = dataEndEpoch; + this.anomalyDetector = anomalyDetector; + this.xContent = xContent; + this.client = client; + this.clientUtil = clientUtil; + this.totalResults = 0; + this.maxEntities = maxEntitiesPerInterval; + this.pageSize = pageSize; + this.expirationEpochMs = expirationEpochMs; + this.clock = clock; + this.indexNameExpressionResolver = indexNameExpressionResolver; + this.clusterService = clusterService; + } + + // a constructor that provide default value of clock + public CompositeRetriever( + long dataStartEpoch, + long dataEndEpoch, + AnomalyDetector anomalyDetector, + NamedXContentRegistry xContent, + Client client, + SecurityClientUtil clientUtil, + long expirationEpochMs, + Settings settings, + int maxEntitiesPerInterval, + int pageSize, + IndexNameExpressionResolver indexNameExpressionResolver, + ClusterService clusterService + ) { + this( + dataStartEpoch, + dataEndEpoch, + anomalyDetector, + xContent, + client, + clientUtil, + expirationEpochMs, + Clock.systemUTC(), + settings, + maxEntitiesPerInterval, + pageSize, + indexNameExpressionResolver, + clusterService + ); + } + + /** + * @return an iterator over pages + * @throws IOException - if we cannot construct valid queries according to + * detector definition + */ + public PageIterator iterator() throws IOException { + RangeQueryBuilder rangeQuery = new RangeQueryBuilder(anomalyDetector.getTimeField()) + .gte(dataStartEpoch) + .lt(dataEndEpoch) + .format("epoch_millis"); + + BoolQueryBuilder internalFilterQuery = new BoolQueryBuilder().filter(anomalyDetector.getFilterQuery()).filter(rangeQuery); + + // multiple categorical fields are supported + CompositeAggregationBuilder composite = AggregationBuilders + .composite( + AGG_NAME_COMP, + anomalyDetector.getCategoryFields().stream().map(f -> new TermsValuesSourceBuilder(f).field(f)).collect(Collectors.toList()) + ) + .size(pageSize); + for (Feature feature : anomalyDetector.getFeatureAttributes()) { + AggregatorFactories.Builder internalAgg = ParseUtils + .parseAggregators(feature.getAggregation().toString(), xContent, feature.getId()); + composite.subAggregation(internalAgg.getAggregatorFactories().iterator().next()); + } + + // In order to optimize the early termination it is advised to set track_total_hits in the request to false. 
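For reference, the page request assembled above corresponds roughly to the following body; the field names, epoch values, and feature sub-aggregation are placeholders that depend on the detector definition, so treat this as an illustrative sketch rather than the exact query:

    /*
     * {
     *   "size": 0,
     *   "track_total_hits": false,
     *   "query": {
     *     "bool": {
     *       "filter": [
     *         { ... detector filter query ... },
     *         { "range": { "<time_field>": { "gte": <dataStartEpoch>, "lt": <dataEndEpoch>, "format": "epoch_millis" } } }
     *       ]
     *     }
     *   },
     *   "aggregations": {
     *     "comp_agg": {
     *       "composite": {
     *         "size": <pageSize>,
     *         "sources": [ { "<category_field>": { "terms": { "field": "<category_field>" } } } ]
     *       },
     *       "aggregations": { "<feature_id>": { ... feature aggregation ... } }
     *     }
     *   }
     * }
     */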
+ SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(internalFilterQuery) + .size(0) + .aggregation(composite) + .trackTotalHits(false); + + return new PageIterator(searchSourceBuilder); + } + + public class PageIterator { + private SearchSourceBuilder source; + // a map from categorical field name to values (type: java.lang.Comparable) + private Map afterKey; + // number of iterations so far + private int iterations; + private long startMs; + + public PageIterator(SearchSourceBuilder source) { + this.source = source; + this.afterKey = null; + this.iterations = 0; + this.startMs = clock.millis(); + } + + /** + * Results are returned using listener + * @param listener Listener to return results + */ + public void next(ActionListener listener) { + iterations++; + + // inject user role while searching. + + SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0]), source); + final ActionListener searchResponseListener = new ActionListener() { + @Override + public void onResponse(SearchResponse response) { + processResponse(response, () -> client.search(searchRequest, this), listener); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }; + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + anomalyDetector.getId(), + client, + searchResponseListener + ); + } + + private void processResponse(SearchResponse response, Runnable retry, ActionListener listener) { + try { + if (shouldRetryDueToEmptyPage(response)) { + updateCompositeAfterKey(response, source); + retry.run(); + return; + } + + Page page = analyzePage(response); + if (afterKey != null) { + updateCompositeAfterKey(response, source); + } + listener.onResponse(page); + } catch (Exception ex) { + listener.onFailure(ex); + } + } + + /** + * + * @param response current response + * @return A page containing + * ** the after key + * ** query source builder to next page if any + * ** a map of composite keys to its values. The values are arranged + * according to the order of anomalyDetector.getEnabledFeatureIds(). 
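To make the ordering guarantee just described concrete, assume a detector whose enabled features are, in order, "the_max" and "the_min"; the example bucket shown in analyzePage below would then be folded into the page like this (an illustrative sketch, not code from this PR):

    // bucket.getKey()                        -> { "service": "app_6", "host": "server_3" }
    // parseBucket(bucket, enabledFeatureIds) -> Optional.of(new double[] { -38.0, -38.0 })
    // results.put(Entity.createEntityByReordering(bucket.getKey()), new double[] { -38.0, -38.0 });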
+ */ + private Page analyzePage(SearchResponse response) { + Optional compositeOptional = getComposite(response); + + if (false == compositeOptional.isPresent()) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Empty response: %s", response)); + } + + CompositeAggregation composite = compositeOptional.get(); + Map results = new HashMap<>(); + /* + * + * Example composite aggregation: + * + "aggregations": { + "my_buckets": { + "after_key": { + "service": "app_6", + "host": "server_3" + }, + "buckets": [ + { + "key": { + "service": "app_6", + "host": "server_3" + }, + "doc_count": 1, + "the_max": { + "value": -38.0 + }, + "the_min": { + "value": -38.0 + } + } + ] + } + } + */ + for (Bucket bucket : composite.getBuckets()) { + Optional featureValues = parseBucket(bucket, anomalyDetector.getEnabledFeatureIds()); + // bucket.getKey() returns a map of categorical field like "host" and its value like "server_1" + if (featureValues.isPresent() && bucket.getKey() != null) { + results.put(Entity.createEntityByReordering(bucket.getKey()), featureValues.get()); + } + } + + totalResults += results.size(); + + afterKey = composite.afterKey(); + return new Page(results); + } + + private void updateCompositeAfterKey(SearchResponse r, SearchSourceBuilder search) { + Optional composite = getComposite(r); + + if (false == composite.isPresent()) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Empty response: %s", r)); + } + + updateSourceAfterKey(composite.get().afterKey(), search); + } + + private boolean shouldRetryDueToEmptyPage(SearchResponse response) { + Optional composite = getComposite(response); + // if there are no buckets but a next page, go fetch it instead of sending an empty response to the client + if (false == composite.isPresent()) { + return false; + } + CompositeAggregation aggr = composite.get(); + return aggr.getBuckets().isEmpty() && aggr.afterKey() != null && !aggr.afterKey().isEmpty(); + } + + Optional getComposite(SearchResponse response) { + // When the source index is a regex like blah*, we will get empty response like + // the following even if no index starting with blah exists. + // {"took":0,"timed_out":false,"_shards":{"total":0,"successful":0,"skipped":0,"failed":0},"hits":{"max_score":0.0,"hits":[]}} + // Without regex, we will get IndexNotFoundException instead.
+ // {"error":{"root_cause":[{"type":"index_not_found_exception","reason":"no such + // index + // [blah]","index":"blah","resource.id":"blah","resource.type":"index_or_alias","index_uuid":"_na_"}],"type":"index_not_found_exception","reason":"no + // such index + // [blah]","index":"blah","resource.id":"blah","resource.type":"index_or_alias","index_uuid":"_na_"},"status":404}% + if (response == null || response.getAggregations() == null) { + List sourceIndices = anomalyDetector.getIndices(); + String[] concreteIndices = indexNameExpressionResolver + .concreteIndexNames(clusterService.state(), IndicesOptions.lenientExpandOpen(), sourceIndices.toArray(new String[0])); + if (concreteIndices.length == 0) { + throw new IndexNotFoundException(String.join(",", sourceIndices)); + } else { + return Optional.empty(); + } + } + Aggregation agg = response.getAggregations().get(AGG_NAME_COMP); + if (agg == null) { + // when current interval has no data + return Optional.empty(); + } + + if (agg instanceof CompositeAggregation) { + return Optional.of((CompositeAggregation) agg); + } + + throw new IllegalArgumentException(String.format(Locale.ROOT, "Not a composite response; {}", agg.getClass())); + } + + /** + * Whether next page exists. Conditions are: + * 1) this is the first time we query (iterations == 0) or afterKey is not null + * 2) next detection interval has not started + * @return true if the iteration has more pages. + */ + public boolean hasNext() { + long now = clock.millis(); + if (expirationEpochMs <= now) { + LOG + .debug( + new ParameterizedMessage( + "Time is up, afterKey: [{}], expirationEpochMs: [{}], now [{}]", + afterKey, + expirationEpochMs, + now + ) + ); + } + if ((iterations > 0 && afterKey == null) || totalResults > maxEntities) { + LOG.debug(new ParameterizedMessage("Finished in [{}] msecs. ", (now - startMs))); + } + return (iterations == 0 || (totalResults > 0 && afterKey != null)) && expirationEpochMs > now && totalResults <= maxEntities; + } + + @Override + public String toString() { + ToStringBuilder toStringBuilder = new ToStringBuilder(this); + + if (afterKey != null) { + toStringBuilder.append("afterKey", afterKey); + } + if (source != null) { + toStringBuilder.append("source", source); + } + + return toStringBuilder.toString(); + } + } + + public class Page { + + Map results; + + public Page(Map results) { + this.results = results; + } + + public boolean isEmpty() { + return results == null || results.isEmpty(); + } + + public Map getResults() { + return results; + } + + @Override + public String toString() { + ToStringBuilder toStringBuilder = new ToStringBuilder(this); + + if (results != null) { + toStringBuilder.append("results", results); + } + + return toStringBuilder.toString(); + } + } +} diff --git a/src/main/java/org/opensearch/ad/feature/FeatureManager.java-e b/src/main/java/org/opensearch/ad/feature/FeatureManager.java-e new file mode 100644 index 000000000..f6fd8ded0 --- /dev/null +++ b/src/main/java/org/opensearch/ad/feature/FeatureManager.java-e @@ -0,0 +1,697 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.feature; + +import static java.util.Arrays.copyOfRange; +import static org.apache.commons.math3.linear.MatrixUtils.createRealMatrix; + +import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.AbstractMap; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.LongStream; +import java.util.stream.Stream; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ThreadedActionListener; +import org.opensearch.ad.CleanState; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.dataprocessor.Imputer; +import org.opensearch.timeseries.model.Entity; + +/** + * A facade managing feature data operations and buffers. + */ +public class FeatureManager implements CleanState { + + private static final Logger logger = LogManager.getLogger(FeatureManager.class); + + // Each anomaly detector has a queue of data points with timestamps (in epoch milliseconds). + private final Map>>> detectorIdsToTimeShingles; + + private final SearchFeatureDao searchFeatureDao; + private final Imputer imputer; + private final Clock clock; + + private final int maxTrainSamples; + private final int maxSampleStride; + private final int trainSampleTimeRangeInHours; + private final int minTrainSamples; + private final double maxMissingPointsRate; + private final int maxNeighborDistance; + private final double previewSampleRate; + private final int maxPreviewSamples; + private final Duration featureBufferTtl; + private final ThreadPool threadPool; + private final String adThreadPoolName; + + /** + * Constructor with dependencies and configuration. 
+ * + * @param searchFeatureDao DAO of features from search + * @param imputer imputer of samples + * @param clock clock for system time + * @param maxTrainSamples max number of samples from search + * @param maxSampleStride max stride between uninterpolated train samples + * @param trainSampleTimeRangeInHours time range in hours for collect train samples + * @param minTrainSamples min number of train samples + * @param maxMissingPointsRate max proportion of shingle with missing points allowed to generate a shingle + * @param maxNeighborDistance max distance (number of intervals) between a missing point and a replacement neighbor + * @param previewSampleRate number of samples to number of all the data points in the preview time range + * @param maxPreviewSamples max number of samples from search for preview features + * @param featureBufferTtl time to live for stale feature buffers + * @param threadPool object through which we can invoke different threadpool using different names + * @param adThreadPoolName AD threadpool's name + */ + public FeatureManager( + SearchFeatureDao searchFeatureDao, + Imputer imputer, + Clock clock, + int maxTrainSamples, + int maxSampleStride, + int trainSampleTimeRangeInHours, + int minTrainSamples, + double maxMissingPointsRate, + int maxNeighborDistance, + double previewSampleRate, + int maxPreviewSamples, + Duration featureBufferTtl, + ThreadPool threadPool, + String adThreadPoolName + ) { + this.searchFeatureDao = searchFeatureDao; + this.imputer = imputer; + this.clock = clock; + this.maxTrainSamples = maxTrainSamples; + this.maxSampleStride = maxSampleStride; + this.trainSampleTimeRangeInHours = trainSampleTimeRangeInHours; + this.minTrainSamples = minTrainSamples; + this.maxMissingPointsRate = maxMissingPointsRate; + this.maxNeighborDistance = maxNeighborDistance; + this.previewSampleRate = previewSampleRate; + this.maxPreviewSamples = maxPreviewSamples; + this.featureBufferTtl = featureBufferTtl; + + this.detectorIdsToTimeShingles = new ConcurrentHashMap<>(); + this.threadPool = threadPool; + this.adThreadPoolName = adThreadPoolName; + } + + /** + * Returns to listener unprocessed features and processed features (such as shingle) for the current data point. + * The listener's onFailure is called with EndRunException on feature query creation errors. + * + * This method sends a single query for historical data for data points (including the current point) that are missing + * from the shingle, and updates the shingle which is persisted to future calls to this method for subsequent time + * intervals. To allow for time variations/delays, an interval is considered missing from the shingle if no data point + * is found within half an interval away. See doc for updateUnprocessedFeatures for details on how the shingle is + * updated. + * + * @param detector anomaly detector for which the features are returned + * @param startTime start time of the data point in epoch milliseconds + * @param endTime end time of the data point in epoch milliseconds + * @param listener onResponse is called with unprocessed features and processed features for the current data point + */ + public void getCurrentFeatures(AnomalyDetector detector, long startTime, long endTime, ActionListener listener) { + + int shingleSize = detector.getShingleSize(); + Deque>> shingle = detectorIdsToTimeShingles + .computeIfAbsent(detector.getId(), id -> new ArrayDeque<>(shingleSize)); + + // To allow for small time variations/delays in running the detector. 
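A small worked example of the tolerance and re-query logic in the lines that follow, using illustrative numbers:

    // interval = 60_000 ms, shingleSize = 3, current interval end = T
    // full shingle end times: [T - 120_000, T - 60_000, T]
    // a stored point within 30_000 ms (interval / 2) of an end time fills that slot;
    // each unfilled end time E becomes the missing range (E - 60_000, E), and all missing
    // ranges are fetched in a single getFeatureSamplesForPeriods call before the shingle is rebuilt.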
+ long maxTimeDifference = detector.getIntervalInMilliseconds() / 2; + Map>> featuresMap = getNearbyPointsForShingle(detector, shingle, endTime, maxTimeDifference) + .collect(Collectors.toMap(Entry::getKey, Entry::getValue)); + + List> missingRanges = getMissingRangesInShingle(detector, featuresMap, endTime); + + if (missingRanges.size() > 0) { + try { + searchFeatureDao.getFeatureSamplesForPeriods(detector, missingRanges, ActionListener.wrap(points -> { + for (int i = 0; i < points.size(); i++) { + Optional point = points.get(i); + long rangeEndTime = missingRanges.get(i).getValue(); + featuresMap.put(rangeEndTime, new SimpleImmutableEntry<>(rangeEndTime, point)); + } + updateUnprocessedFeatures(detector, shingle, featuresMap, endTime, listener); + }, listener::onFailure)); + } catch (IOException e) { + listener.onFailure(new EndRunException(detector.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, true)); + } + } else { + listener.onResponse(getProcessedFeatures(shingle, detector, endTime)); + } + } + + private List> getMissingRangesInShingle( + AnomalyDetector detector, + Map>> featuresMap, + long endTime + ) { + long intervalMilli = detector.getIntervalInMilliseconds(); + int shingleSize = detector.getShingleSize(); + return getFullShingleEndTimes(endTime, intervalMilli, shingleSize) + .filter(time -> !featuresMap.containsKey(time)) + .mapToObj(time -> new SimpleImmutableEntry<>(time - intervalMilli, time)) + .collect(Collectors.toList()); + } + + /** + * Updates the shingle to contain one Optional data point for each of shingleSize consecutive time intervals, ending + * with the current interval. Each entry in the shingle contains the timestamp of the data point as the key, and the + * data point wrapped in an Optional. If the data point is missing (even after querying, since this method is invoked + * after querying), an entry with an empty Optional value is stored in the shingle to prevent subsequent calls to + * getCurrentFeatures from re-querying the missing data point again. + * + * Note that in the presence of time variations/delays up to half an interval, the shingle stores the actual original + * end times of the data points, not the computed end times that were calculated based on the current endTime. + * Ex: if data points are queried at times 100, 201, 299, then the shingle will contain [100: x, 201: y, 299: z]. + * + * @param detector anomaly detector for which the features are returned. + * @param shingle buffer which persists the past shingleSize data points to subsequent calls of getCurrentFeature. + * Each entry contains the timestamp of the data point and an optional data point value. + * @param featuresMap A map where the keys are the computed millisecond timestamps associated with intervals in the + * shingle, and the values are entries that contain the actual timestamp of the data point and + * an optional data point value. + * @param listener onResponse is called with unprocessed features and processed features for the current data point. 
+ */ + private void updateUnprocessedFeatures( + AnomalyDetector detector, + Deque>> shingle, + Map>> featuresMap, + long endTime, + ActionListener listener + ) { + shingle.clear(); + getFullShingleEndTimes(endTime, detector.getIntervalInMilliseconds(), detector.getShingleSize()) + .mapToObj(time -> featuresMap.getOrDefault(time, new SimpleImmutableEntry<>(time, Optional.empty()))) + .forEach(e -> shingle.add(e)); + + listener.onResponse(getProcessedFeatures(shingle, detector, endTime)); + } + + private double[][] filterAndFill(Deque>> shingle, long endTime, AnomalyDetector detector) { + int shingleSize = detector.getShingleSize(); + Deque>> filteredShingle = shingle + .stream() + .filter(e -> e.getValue().isPresent()) + .collect(Collectors.toCollection(ArrayDeque::new)); + double[][] result = null; + if (filteredShingle.size() >= shingleSize - getMaxMissingPoints(shingleSize)) { + // Imputes missing data points with the values of neighboring data points. + long maxMillisecondsDifference = maxNeighborDistance * detector.getIntervalInMilliseconds(); + result = getNearbyPointsForShingle(detector, filteredShingle, endTime, maxMillisecondsDifference) + .map(e -> e.getValue().getValue().orElse(null)) + .filter(d -> d != null) + .toArray(double[][]::new); + + if (result.length < shingleSize) { + result = null; + } + } + return result; + } + + /** + * Helper method that associates data points (along with their actual timestamps) to the intervals of a full shingle. + * + * Depending on the timestamp tolerance (maxMillisecondsDifference), this can be used to allow for small time + * variations/delays in running the detector, or used for imputing missing points in the shingle with neighboring points. + * + * @return A stream of entries, where the key is the computed millisecond timestamp associated with an interval in + * the shingle, and the value is an entry that contains the actual timestamp of the data point and an optional data + * point value. + */ + private Stream>>> getNearbyPointsForShingle( + AnomalyDetector detector, + Deque>> shingle, + long endTime, + long maxMillisecondsDifference + ) { + long intervalMilli = detector.getIntervalInMilliseconds(); + int shingleSize = detector.getShingleSize(); + TreeMap> search = new TreeMap<>( + shingle.stream().collect(Collectors.toMap(Entry::getKey, Entry::getValue)) + ); + return getFullShingleEndTimes(endTime, intervalMilli, shingleSize).mapToObj(t -> { + Optional>> after = Optional.ofNullable(search.ceilingEntry(t)); + Optional>> before = Optional.ofNullable(search.floorEntry(t)); + return after + .filter(a -> Math.abs(t - a.getKey()) <= before.map(b -> Math.abs(t - b.getKey())).orElse(Long.MAX_VALUE)) + .map(Optional::of) + .orElse(before) + .filter(e -> Math.abs(t - e.getKey()) < maxMillisecondsDifference) + .map(e -> new SimpleImmutableEntry<>(t, e)); + }).filter(Optional::isPresent).map(Optional::get); + } + + private LongStream getFullShingleEndTimes(long endTime, long intervalMilli, int shingleSize) { + return LongStream.rangeClosed(1, shingleSize).map(i -> endTime - (shingleSize - i) * intervalMilli); + } + + /** + * Returns to listener data for cold-start training. + * + * Training data starts with getting samples from (costly) search. + * Samples are increased in dimension via shingling. 
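To give a sense of the scale of the cold-start search described above, here is the sample-range arithmetic performed by getColdStartSampleRanges further below, with illustrative parameter values (the actual defaults are configured elsewhere):

    // detector interval = 10 minutes, trainSampleTimeRangeInHours = 24, minTrainSamples = 24
    // numSamples = max(24h / 10min, 24) = max(144, 24) = 144
    // getColdStartSampleRanges returns 144 back-to-back one-interval (start, end) windows ending at
    // the latest data timestamp; the fetched points are then shingled in processColdStartSamples
    // to build the training matrix.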
+ * + * @param detector contains data info (indices, documents, etc) + * @param listener onResponse is called with data for cold-start training, or empty if unavailable + * onFailure is called with EndRunException on feature query creation errors + */ + public void getColdStartData(AnomalyDetector detector, ActionListener> listener) { + ActionListener> latestTimeListener = ActionListener + .wrap(latest -> getColdStartSamples(latest, detector, listener), listener::onFailure); + searchFeatureDao + .getLatestDataTime(detector, new ThreadedActionListener<>(logger, threadPool, adThreadPoolName, latestTimeListener, false)); + } + + private void getColdStartSamples(Optional latest, AnomalyDetector detector, ActionListener> listener) { + int shingleSize = detector.getShingleSize(); + if (latest.isPresent()) { + List> sampleRanges = getColdStartSampleRanges(detector, latest.get()); + try { + ActionListener>> getFeaturesListener = ActionListener + .wrap(samples -> processColdStartSamples(samples, shingleSize, listener), listener::onFailure); + searchFeatureDao + .getFeatureSamplesForPeriods( + detector, + sampleRanges, + new ThreadedActionListener<>(logger, threadPool, adThreadPoolName, getFeaturesListener, false) + ); + } catch (IOException e) { + listener.onFailure(new EndRunException(detector.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, true)); + } + } else { + listener.onResponse(Optional.empty()); + } + } + + private void processColdStartSamples(List> samples, int shingleSize, ActionListener> listener) { + List shingles = new ArrayList<>(); + LinkedList> currentShingle = new LinkedList<>(); + for (Optional sample : samples) { + currentShingle.addLast(sample); + if (currentShingle.size() == shingleSize) { + sample.ifPresent(s -> fillAndShingle(currentShingle, shingleSize).ifPresent(shingles::add)); + currentShingle.remove(); + } + } + listener.onResponse(Optional.of(shingles.toArray(new double[0][0])).filter(results -> results.length > 0)); + } + + private Optional fillAndShingle(LinkedList> shingle, int shingleSize) { + Optional result = null; + if (shingle.stream().filter(s -> s.isPresent()).count() >= shingleSize - getMaxMissingPoints(shingleSize)) { + TreeMap search = new TreeMap<>( + IntStream + .range(0, shingleSize) + .filter(i -> shingle.get(i).isPresent()) + .boxed() + .collect(Collectors.toMap(i -> i, i -> shingle.get(i).get())) + ); + result = Optional.of(IntStream.range(0, shingleSize).mapToObj(i -> { + Optional> after = Optional.ofNullable(search.ceilingEntry(i)); + Optional> before = Optional.ofNullable(search.floorEntry(i)); + return after + .filter(a -> Math.abs(i - a.getKey()) <= before.map(b -> Math.abs(i - b.getKey())).orElse(Integer.MAX_VALUE)) + .map(Optional::of) + .orElse(before) + .filter(e -> Math.abs(i - e.getKey()) <= maxNeighborDistance) + .map(Entry::getValue) + .orElse(null); + }).filter(d -> d != null).toArray(double[][]::new)) + .filter(d -> d.length == shingleSize) + .map(d -> batchShingle(d, shingleSize)[0]); + } else { + result = Optional.empty(); + } + return result; + } + + private List> getColdStartSampleRanges(AnomalyDetector detector, long endMillis) { + long interval = detector.getIntervalInMilliseconds(); + int numSamples = Math.max((int) (Duration.ofHours(this.trainSampleTimeRangeInHours).toMillis() / interval), this.minTrainSamples); + return IntStream + .rangeClosed(1, numSamples) + .mapToObj(i -> new SimpleImmutableEntry<>(endMillis - (numSamples - i + 1) * interval, endMillis - (numSamples - i) * interval)) + .collect(Collectors.toList()); + 
} + + /** + * Shingles a batch of data points by concatenating neighboring data points. + * + * @param points M, N where M is the number of data points and N is the dimension of a point + * @param shingleSize the size of a shingle + * @return P, Q where P = M - {@code shingleSize} + 1 and Q = N * {@code shingleSize} + * @throws IllegalArgumentException when input is invalid + */ + public double[][] batchShingle(double[][] points, int shingleSize) { + if (points.length == 0 || points[0].length == 0 || points.length < shingleSize || shingleSize < 1) { + throw new IllegalArgumentException("Invalid data for shingling."); + } + int numPoints = points.length; + int dimPoint = points[0].length; + int numShingles = numPoints - shingleSize + 1; + int dimShingle = dimPoint * shingleSize; + double[][] shingles = new double[numShingles][dimShingle]; + for (int i = 0; i < numShingles; i++) { + for (int j = 0; j < shingleSize; j++) { + System.arraycopy(points[i + j], 0, shingles[i], j * dimPoint, dimPoint); + } + } + return shingles; + } + + /** + * Deletes managed features for the detector. + * + * @param detectorId ID of the detector + */ + @Override + public void clear(String detectorId) { + detectorIdsToTimeShingles.remove(detectorId); + } + + /** + * Does maintenance work. + * + * The current implementation removes feature buffers that are updated more than ttlFeatureBuffer (3 days for example) ago. + * The cleanup is needed since feature buffers are not explicitly deleted after a detector is deleted or relocated. + */ + public void maintenance() { + try { + detectorIdsToTimeShingles + .entrySet() + .removeIf( + idQueue -> Optional + .ofNullable(idQueue.getValue().peekLast()) + .map(p -> Instant.ofEpochMilli(p.getKey()).plus(featureBufferTtl).isBefore(clock.instant())) + .orElse(true) + ); + } catch (Exception e) { + logger.warn("Caught exception during maintenance", e); + } + } + + /** + * Returns the entities for preview to listener + * @param detector detector config + * @param startTime start of the range in epoch milliseconds + * @param endTime end of the range in epoch milliseconds + * @param listener onResponse is called when entities are found + */ + public void getPreviewEntities(AnomalyDetector detector, long startTime, long endTime, ActionListener> listener) { + searchFeatureDao.getHighestCountEntities(detector, startTime, endTime, listener); + } + + /** + * Returns to listener feature data points (unprocessed and processed) from the period for preview purpose for specific entity. + * + * Due to the constraints (workload, latency) from preview, a small number of data samples are from actual + * query results and the remaining are from interpolation. The results are approximate to the actual features. 
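A worked example of the shingling arithmetic documented on batchShingle above, with illustrative numbers: M = 4 points of dimension N = 2 and a shingle size of 3 give P = 2 shingles of dimension Q = 6.

    double[][] points = {
        { 1.0, 10.0 },
        { 2.0, 20.0 },
        { 3.0, 30.0 },
        { 4.0, 40.0 }
    };
    // batchShingle(points, 3) concatenates each run of 3 consecutive points:
    //   { 1.0, 10.0, 2.0, 20.0, 3.0, 30.0 }
    //   { 2.0, 20.0, 3.0, 30.0, 4.0, 40.0 }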
+ * + * @param detector detector info containing indices, features, interval, etc + * @param entity entity specified + * @param startMilli start of the range in epoch milliseconds + * @param endMilli end of the range in epoch milliseconds + * @param listener onResponse is called with time ranges, unprocessed features, + * and processed features of the data points from the period + * onFailure is called with IllegalArgumentException when there is no data to preview + * @throws IOException if a user gives wrong query input when defining a detector + */ + public void getPreviewFeaturesForEntity( + AnomalyDetector detector, + Entity entity, + long startMilli, + long endMilli, + ActionListener listener + ) throws IOException { + // TODO refactor this common lines so that these code can be run for 1 time for all entities + Entry>, Integer> sampleRangeResults = getSampleRanges(detector, startMilli, endMilli); + List> sampleRanges = sampleRangeResults.getKey(); + int stride = sampleRangeResults.getValue(); + int shingleSize = detector.getShingleSize(); + + getPreviewSamplesInRangesForEntity(detector, sampleRanges, entity, getFeatureSamplesListener(stride, shingleSize, listener)); + } + + private ActionListener>, double[][]>> getFeatureSamplesListener( + int stride, + int shingleSize, + ActionListener listener + ) { + return ActionListener.wrap(samples -> { + List> searchTimeRange = samples.getKey(); + if (searchTimeRange.size() == 0) { + listener.onFailure(new IllegalArgumentException("No data to preview anomaly detection.")); + return; + } + double[][] sampleFeatures = samples.getValue(); + List> previewRanges = getPreviewRanges(searchTimeRange, stride, shingleSize); + Entry previewFeatures = getPreviewFeatures(sampleFeatures, stride, shingleSize); + listener.onResponse(new Features(previewRanges, previewFeatures.getKey(), previewFeatures.getValue())); + }, listener::onFailure); + } + + /** + * Returns to listener feature data points (unprocessed and processed) from the period for preview purpose. + * + * Due to the constraints (workload, latency) from preview, a small number of data samples are from actual + * query results and the remaining are from interpolation. The results are approximate to the actual features. + * + * @param detector detector info containing indices, features, interval, etc + * @param startMilli start of the range in epoch milliseconds + * @param endMilli end of the range in epoch milliseconds + * @param listener onResponse is called with time ranges, unprocessed features, + * and processed features of the data points from the period + * onFailure is called with IllegalArgumentException when there is no data to preview + * @throws IOException if a user gives wrong query input when defining a detector + */ + public void getPreviewFeatures(AnomalyDetector detector, long startMilli, long endMilli, ActionListener listener) + throws IOException { + Entry>, Integer> sampleRangeResults = getSampleRanges(detector, startMilli, endMilli); + List> sampleRanges = sampleRangeResults.getKey(); + int stride = sampleRangeResults.getValue(); + int shingleSize = detector.getShingleSize(); + + getSamplesForRanges(detector, sampleRanges, getFeatureSamplesListener(stride, shingleSize, listener)); + } + + /** + * Gets time ranges of sampled data points. + * + * To reduce workload/latency from search, most data points in the preview time ranges are not from search results. + * This implementation selects up to maxPreviewSamples evenly spaced points from the entire time range. 
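The sampling arithmetic in getSampleRanges below works out as follows for illustrative values (previewSampleRate and maxPreviewSamples are settings whose defaults are not shown here):

    // preview range = 24 h, detector interval = 10 min -> numBuckets = 144
    // previewSampleRate = 0.25, maxPreviewSamples = 60 -> numSamples = min(144 * 0.25, 60) = 36
    // stride = max(1, floor(144 / 36)) = 4             -> every 4th bucket is actually queried
    // numStrides = ceil(144 / 4) = 36 sampled (start, end) ranges; the buckets in between are later
    // reconstructed by interpolation in getPreviewRanges and getPreviewFeatures.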
+ * + * @return key is a list of sampled time ranges, value is the stride between samples + */ + private Entry>, Integer> getSampleRanges(AnomalyDetector detector, long startMilli, long endMilli) { + long start = truncateToMinute(startMilli); + long end = truncateToMinute(endMilli); + long bucketSize = detector.getIntervalInMilliseconds(); + int numBuckets = (int) Math.floor((end - start) / (double) bucketSize); + int numSamples = (int) Math.max(Math.min(numBuckets * previewSampleRate, maxPreviewSamples), 1); + int stride = (int) Math.max(1, Math.floor((double) numBuckets / numSamples)); + int numStrides = (int) Math.ceil(numBuckets / (double) stride); + List> sampleRanges = Stream + .iterate(start, i -> i + stride * bucketSize) + .limit(numStrides) + .map(time -> new SimpleImmutableEntry<>(time, time + bucketSize)) + .collect(Collectors.toList()); + return new SimpleImmutableEntry<>(sampleRanges, stride); + } + + /** + * Gets search results in the sampled time ranges for specified entity. + * + * @param entity specified entity + * @param listener handle search results map: key is time ranges, value is corresponding search results + * @throws IOException if a user gives wrong query input when defining a detector + */ + void getPreviewSamplesInRangesForEntity( + AnomalyDetector detector, + List> sampleRanges, + Entity entity, + ActionListener>, double[][]>> listener + ) throws IOException { + searchFeatureDao + .getColdStartSamplesForPeriods(detector, sampleRanges, entity, true, getSamplesRangesListener(sampleRanges, listener)); + } + + private ActionListener>> getSamplesRangesListener( + List> sampleRanges, + ActionListener>, double[][]>> listener + ) { + return ActionListener.wrap(featureSamples -> { + List> ranges = new ArrayList<>(featureSamples.size()); + List samples = new ArrayList<>(featureSamples.size()); + for (int i = 0; i < featureSamples.size(); i++) { + Entry currentRange = sampleRanges.get(i); + featureSamples.get(i).ifPresent(sample -> { + ranges.add(currentRange); + samples.add(sample); + }); + } + listener.onResponse(new SimpleImmutableEntry<>(ranges, samples.toArray(new double[0][0]))); + }, listener::onFailure); + } + + /** + * Gets search results for the sampled time ranges. + * + * @param listener handle search results map: key is time ranges, value is corresponding search results + * @throws IOException if a user gives wrong query input when defining a detector + */ + void getSamplesForRanges( + AnomalyDetector detector, + List> sampleRanges, + ActionListener>, double[][]>> listener + ) throws IOException { + searchFeatureDao.getFeatureSamplesForPeriods(detector, sampleRanges, getSamplesRangesListener(sampleRanges, listener)); + } + + /** + * Gets time ranges for the data points in the preview range that begins with the first + * sample time range and ends with the last. 
+ * + * @param ranges time ranges of samples + * @param stride the number of data points between samples + * @param shingleSize the size of a shingle + * @return time ranges for all data points + */ + private List> getPreviewRanges(List> ranges, int stride, int shingleSize) { + double[] rangeStarts = ranges.stream().mapToDouble(Entry::getKey).toArray(); + double[] rangeEnds = ranges.stream().mapToDouble(Entry::getValue).toArray(); + double[] previewRangeStarts = imputer.impute(new double[][] { rangeStarts }, stride * (ranges.size() - 1) + 1)[0]; + double[] previewRangeEnds = imputer.impute(new double[][] { rangeEnds }, stride * (ranges.size() - 1) + 1)[0]; + List> previewRanges = IntStream + .range(shingleSize - 1, previewRangeStarts.length) + .mapToObj(i -> new SimpleImmutableEntry<>((long) previewRangeStarts[i], (long) previewRangeEnds[i])) + .collect(Collectors.toList()); + return previewRanges; + } + + /** + * Gets unprocessed and processed features for the data points in the preview range. + * + * To reduce workload on search, the data points within the preview range are interpolated based on + * sample query results. Unprocessed features are interpolated query results. + * Processed features are inputs to models, transformed (such as shingle) from unprocessed features. + * + * @return unprocessed and processed features + */ + private Entry getPreviewFeatures(double[][] samples, int stride, int shingleSize) { + Entry unprocessedAndProcessed = Optional + .of(samples) + .map(m -> transpose(m)) + .map(m -> imputer.impute(m, stride * (samples.length - 1) + 1)) + .map(m -> transpose(m)) + .map(m -> new SimpleImmutableEntry<>(copyOfRange(m, shingleSize - 1, m.length), batchShingle(m, shingleSize))) + .get(); + return unprocessedAndProcessed; + } + + public double[][] transpose(double[][] matrix) { + return createRealMatrix(matrix).transpose().getData(); + } + + private long truncateToMinute(long epochMillis) { + return Instant.ofEpochMilli(epochMillis).truncatedTo(ChronoUnit.MINUTES).toEpochMilli(); + } + + /** + * @return max number of missing points allowed to generate a shingle + */ + private int getMaxMissingPoints(int shingleSize) { + return (int) Math.floor(shingleSize * maxMissingPointsRate); + } + + public int getShingleSize(String detectorId) { + Deque>> shingle = detectorIdsToTimeShingles.get(detectorId); + if (shingle != null) { + return Math.toIntExact(shingle.stream().filter(entry -> entry.getValue().isPresent()).count()); + } else { + return -1; + } + } + + public void getFeatureDataPointsByBatch( + AnomalyDetector detector, + Entity entity, + long startTime, + long endTime, + ActionListener>> listener + ) { + try { + searchFeatureDao.getFeaturesForPeriodByBatch(detector, entity, startTime, endTime, ActionListener.wrap(points -> { + logger.debug("features size: {}", points.size()); + listener.onResponse(points); + }, listener::onFailure)); + } catch (Exception e) { + logger.error("Failed to get features for detector: " + detector.getId()); + listener.onFailure(e); + } + } + + public SinglePointFeatures getShingledFeatureForHistoricalAnalysis( + AnomalyDetector detector, + Deque>> shingle, + Optional dataPoint, + long endTime + ) { + while (shingle.size() >= detector.getShingleSize()) { + shingle.poll(); + } + shingle.add(new AbstractMap.SimpleEntry<>(endTime, dataPoint)); + + return getProcessedFeatures(shingle, detector, endTime); + } + + private SinglePointFeatures getProcessedFeatures( + Deque>> shingle, + AnomalyDetector detector, + long endTime + ) { + int shingleSize = 
detector.getShingleSize(); + Optional currentPoint = shingle.peekLast().getValue(); + return new SinglePointFeatures( + currentPoint, + Optional + // if current point is not present or current shingle has more missing data points than + // max missing rate, will return null + .ofNullable(currentPoint.isPresent() ? filterAndFill(shingle, endTime, detector) : null) + .map(points -> batchShingle(points, shingleSize)[0]) + ); + } + +} diff --git a/src/main/java/org/opensearch/ad/feature/Features.java-e b/src/main/java/org/opensearch/ad/feature/Features.java-e new file mode 100644 index 000000000..de347b78f --- /dev/null +++ b/src/main/java/org/opensearch/ad/feature/Features.java-e @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.feature; + +import java.util.Arrays; +import java.util.List; +import java.util.Map.Entry; +import java.util.Objects; + +/** + * Data object for features internally used with ML. + */ +public class Features { + + private final List> timeRanges; + private final double[][] unprocessedFeatures; + private final double[][] processedFeatures; + + /** + * Constructor with all arguments. + * + * @param timeRanges the time ranges of feature data points + * @param unprocessedFeatures unprocessed feature values (such as from aggregates from search) + * @param processedFeatures processed feature values (such as shingle) + */ + public Features(List> timeRanges, double[][] unprocessedFeatures, double[][] processedFeatures) { + this.timeRanges = timeRanges; + this.unprocessedFeatures = unprocessedFeatures; + this.processedFeatures = processedFeatures; + } + + /** + * Returns the time ranges of feature data points. + * + * @return list of pairs of start and end in epoch milliseconds + */ + public List> getTimeRanges() { + return timeRanges; + } + + /** + * Returns unprocessed features (such as from aggregates from search). + * + * @return unprocessed features of data points + */ + public double[][] getUnprocessedFeatures() { + return unprocessedFeatures; + } + + /** + * Returns processed features (such as shingle). + * + * @return processed features of data points + */ + public double[][] getProcessedFeatures() { + return processedFeatures; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + Features that = (Features) o; + return Objects.equals(this.timeRanges, that.timeRanges) + && Arrays.deepEquals(this.unprocessedFeatures, that.unprocessedFeatures) + && Arrays.deepEquals(this.processedFeatures, that.processedFeatures); + } + + @Override + public int hashCode() { + return Objects.hash(timeRanges, unprocessedFeatures, processedFeatures); + } +} diff --git a/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java-e b/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java-e new file mode 100644 index 000000000..557e98fd7 --- /dev/null +++ b/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java-e @@ -0,0 +1,931 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. 
See + * GitHub history for details. + */ + +package org.opensearch.ad.feature; + +import static org.apache.commons.math3.linear.MatrixUtils.createRealMatrix; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.PAGE_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.PREVIEW_TIMEOUT_IN_MILLIS; +import static org.opensearch.timeseries.util.ParseUtils.batchFeatureQuery; + +import java.io.IOException; +import java.time.Clock; +import java.time.ZonedDateTime; +import java.util.AbstractMap.SimpleEntry; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.PipelineAggregatorBuilders; +import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; +import org.opensearch.search.aggregations.bucket.composite.InternalComposite; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.range.InternalDateRange; +import org.opensearch.search.aggregations.bucket.range.InternalDateRange.Bucket; +import org.opensearch.search.aggregations.bucket.terms.Terms; +import org.opensearch.search.aggregations.metrics.Min; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.dataprocessor.Imputer; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.util.ParseUtils; + +/** + * DAO for features from search. 
+ */ +public class SearchFeatureDao extends AbstractRetriever { + + protected static final String AGG_NAME_MIN = "min_timefield"; + protected static final String AGG_NAME_TOP = "top_agg"; + + private static final Logger logger = LogManager.getLogger(SearchFeatureDao.class); + + // Dependencies + private final Client client; + private final NamedXContentRegistry xContent; + private final Imputer imputer; + private final SecurityClientUtil clientUtil; + private volatile int maxEntitiesForPreview; + private volatile int pageSize; + private final int minimumDocCountForPreview; + private long previewTimeoutInMilliseconds; + private Clock clock; + + // used for testing as we can mock clock + public SearchFeatureDao( + Client client, + NamedXContentRegistry xContent, + Imputer imputer, + SecurityClientUtil clientUtil, + Settings settings, + ClusterService clusterService, + int minimumDocCount, + Clock clock, + int maxEntitiesForPreview, + int pageSize, + long previewTimeoutInMilliseconds + ) { + this.client = client; + this.xContent = xContent; + this.imputer = imputer; + this.clientUtil = clientUtil; + this.maxEntitiesForPreview = maxEntitiesForPreview; + + this.pageSize = pageSize; + + if (clusterService != null) { + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_ENTITIES_FOR_PREVIEW, it -> this.maxEntitiesForPreview = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(PAGE_SIZE, it -> this.pageSize = it); + } + this.minimumDocCountForPreview = minimumDocCount; + this.previewTimeoutInMilliseconds = previewTimeoutInMilliseconds; + this.clock = clock; + } + + /** + * Constructor injection. + * + * @param client ES client for queries + * @param xContent ES XContentRegistry + * @param imputer imputer for missing values + * @param clientUtil utility for ES client + * @param settings ES settings + * @param clusterService ES ClusterService + * @param minimumDocCount minimum doc count required for an entity; used to + * make sure an entity has enough samples for preview + */ + public SearchFeatureDao( + Client client, + NamedXContentRegistry xContent, + Imputer imputer, + SecurityClientUtil clientUtil, + Settings settings, + ClusterService clusterService, + int minimumDocCount + ) { + this( + client, + xContent, + imputer, + clientUtil, + settings, + clusterService, + minimumDocCount, + Clock.systemUTC(), + MAX_ENTITIES_FOR_PREVIEW.get(settings), + PAGE_SIZE.get(settings), + PREVIEW_TIMEOUT_IN_MILLIS + ); + } + + /** + * Returns to listener the epoch time of the latest data under the detector.
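For orientation, the request built by getLatestDataTime below reduces to a single max aggregation over the detector's time field, roughly as follows (the aggregation name comes from CommonName.AGG_NAME_MAX_TIME; treat this as a sketch):

    /*
     * { "size": 0, "aggregations": { "<AGG_NAME_MAX_TIME>": { "max": { "field": "<time_field>" } } } }
     */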
+ * + * @param detector info about the data + * @param listener onResponse is called with the epoch time of the latset data under the detector + */ + public void getLatestDataTime(AnomalyDetector detector, ActionListener> listener) { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(detector.getTimeField())) + .size(0); + SearchRequest searchRequest = new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener + .wrap(response -> listener.onResponse(ParseUtils.getLatestDataTime(response)), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + detector.getId(), + client, + searchResponseListener + ); + } + + /** + * Get list of entities with high count in descending order within specified time range + * @param detector detector config + * @param startTime start time of time range + * @param endTime end time of time range + * @param listener listener to return back the entities + */ + public void getHighestCountEntities(AnomalyDetector detector, long startTime, long endTime, ActionListener> listener) { + getHighestCountEntities(detector, startTime, endTime, maxEntitiesForPreview, minimumDocCountForPreview, pageSize, listener); + } + + /** + * Get list of entities with high count in descending order within specified time range + * @param detector detector config + * @param startTime start time of time range + * @param endTime end time of time range + * @param maxEntitiesSize max top entities + * @param minimumDocCount minimum doc count for top entities + * @param pageSize page size when query multi-category HC detector's top entities + * @param listener listener to return back the entities + */ + public void getHighestCountEntities( + AnomalyDetector detector, + long startTime, + long endTime, + int maxEntitiesSize, + int minimumDocCount, + int pageSize, + ActionListener> listener + ) { + if (!detector.isHighCardinality()) { + listener.onResponse(null); + return; + } + + RangeQueryBuilder rangeQuery = new RangeQueryBuilder(detector.getTimeField()) + .from(startTime) + .to(endTime) + .format("epoch_millis") + .includeLower(true) + .includeUpper(false); + + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().filter(rangeQuery).filter(detector.getFilterQuery()); + AggregationBuilder bucketAggs = null; + + if (detector.getCategoryFields().size() == 1) { + bucketAggs = AggregationBuilders.terms(AGG_NAME_TOP).size(maxEntitiesSize).field(detector.getCategoryFields().get(0)); + } else { + /* + * We don't have an efficient solution for terms aggregation on multiple fields. + * Terms aggregation does not support collecting terms from multiple fields in the same document. + * We have to work around the limitation by using a script to retrieve terms from multiple fields. + * The workaround disables the global ordinals optimization and thus causes a markedly longer + * slowdown. This is because scripting is tugging on memory and has to iterate through + * all of the documents at least once to create run-time fields. + * + * We evaluated composite and terms aggregation using a generated data set with one + * million entities. Each entity has two documents. Composite aggregation finishes + * around 40 seconds. 
Terms aggregation performs differently on different clusters. + * On a 3 data node cluster, terms aggregation does not finish running within 2 hours + * on a 5 primary shard index. On a 15 data node cluster, terms aggregation needs 217 seconds + * on a 15 primary shard index. On a 30 data node cluster, terms aggregation needs 47 seconds + * on a 30 primary shard index. + * + * Here we work around the problem using composite aggregation. Composite aggregation cannot + * give top entities without collecting all aggregated results. Paginated results are returned + * in the natural order of composite keys. This is fine for Preview API. Preview API needs the + * top entities to make sure there is enough data for training and showing the results. We + * can paginate entities and filter out entities that do not have enough docs (e.g., 256 docs). + * As long as we have collected the desired number of entities (e.g., 5 entities), we can stop + * pagination. + * + * Example composite query: + * { + * "size": 0, + * "query": { + * "bool": { + * "filter": [{ + * "range": { + * "@timestamp": { + * "from": 1626118340000, + * "to": 1626294912000, + * "include_lower": true, + * "include_upper": false, + * "format": "epoch_millis", + * "boost": 1.0 + * } + * } + * }, { + * "match_all": { + * "boost": 1.0 + * } + * }], + * "adjust_pure_negative": true, + * "boost": 1.0 + * } + * }, + * "track_total_hits": -1, + * "aggregations": { + * "top_agg": { + * "composite": { + * "size": 1, + * "sources": [{ + * "service": { + * "terms": { + * "field": "service", + * "missing_bucket": false, + * "order": "asc" + * } + * } + * }, { + * "host": { + * "terms": { + * "field": "host", + * "missing_bucket": false, + * "order": "asc" + * } + * } + * }] + * }, + * "aggregations": { + * "bucketSort": { + * "bucket_sort": { + * "sort": [{ + * "_count": { + * "order": "desc" + * } + * }], + * "from": 0, + * "size": 5, + * "gap_policy": "SKIP" + * } + * } + * } + * } + * } + * } + * + */ + bucketAggs = AggregationBuilders + .composite( + AGG_NAME_TOP, + detector.getCategoryFields().stream().map(f -> new TermsValuesSourceBuilder(f).field(f)).collect(Collectors.toList()) + ) + .size(pageSize) + .subAggregation( + PipelineAggregatorBuilders + .bucketSort("bucketSort", Arrays.asList(new FieldSortBuilder("_count").order(SortOrder.DESC))) + .size(maxEntitiesSize) + ); + } + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(boolQueryBuilder) + .aggregation(bucketAggs) + .trackTotalHits(false) + .size(0); + SearchRequest searchRequest = new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder); + final ActionListener searchResponseListener = new TopEntitiesListener( + listener, + detector, + searchSourceBuilder, + // TODO: tune timeout for historical analysis based on performance test result + clock.millis() + previewTimeoutInMilliseconds, + maxEntitiesSize, + minimumDocCount + ); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + detector.getId(), + client, + searchResponseListener + ); + } + + class TopEntitiesListener implements ActionListener { + private ActionListener> listener; + private AnomalyDetector detector; + private List topEntities; + private SearchSourceBuilder searchSourceBuilder; + private long expirationEpochMs; + private long minimumDocCount; + private int maxEntitiesSize; + + 
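The composite-aggregation workaround described above returns pages in composite-key order, so the listener defined below keeps paging through afterKey until enough qualifying entities are collected or the preview timeout passes. A caller-side sketch of the public entry point; the dao, detector, and time range are assumed inputs, and the println handling is illustrative only.

    import java.util.List;

    import org.opensearch.action.ActionListener;
    import org.opensearch.ad.model.AnomalyDetector;
    import org.opensearch.timeseries.model.Entity;

    final class PreviewEntitiesExample {
        // Hypothetical caller: fetch the highest-count entities for a preview window.
        static void preview(SearchFeatureDao dao, AnomalyDetector detector, long startMillis, long endMillis) {
            dao.getHighestCountEntities(detector, startMillis, endMillis, ActionListener.wrap((List<Entity> entities) -> {
                if (entities == null) {
                    // single-stream detectors have no category fields, so null is returned
                    return;
                }
                entities.forEach(entity -> System.out.println("top entity: " + entity));
            }, e -> System.err.println("failed to collect top entities: " + e)));
        }
    }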
TopEntitiesListener( + ActionListener> listener, + AnomalyDetector detector, + SearchSourceBuilder searchSourceBuilder, + long expirationEpochMs, + int maxEntitiesSize, + int minimumDocCount + ) { + this.listener = listener; + this.detector = detector; + this.topEntities = new ArrayList<>(); + this.searchSourceBuilder = searchSourceBuilder; + this.expirationEpochMs = expirationEpochMs; + this.maxEntitiesSize = maxEntitiesSize; + this.minimumDocCount = minimumDocCount; + } + + @Override + public void onResponse(SearchResponse response) { + try { + Aggregations aggs = response.getAggregations(); + if (aggs == null) { + // This would indicate some bug or some opensearch core changes that we are not aware of (we don't keep up-to-date with + // the large amounts of changes there). For example, they may change to if there are results return it; otherwise return + // null instead of an empty Aggregations as they currently do. + logger.warn("Unexpected null aggregation."); + listener.onResponse(topEntities); + return; + } + + Aggregation aggrResult = aggs.get(AGG_NAME_TOP); + if (aggrResult == null) { + listener.onFailure(new IllegalArgumentException("Fail to find valid aggregation result")); + return; + } + + if (detector.getCategoryFields().size() == 1) { + topEntities = ((Terms) aggrResult) + .getBuckets() + .stream() + .map(bucket -> bucket.getKeyAsString()) + .collect(Collectors.toList()) + .stream() + .map(entityValue -> Entity.createSingleAttributeEntity(detector.getCategoryFields().get(0), entityValue)) + .collect(Collectors.toList()); + listener.onResponse(topEntities); + } else { + CompositeAggregation compositeAgg = (CompositeAggregation) aggrResult; + List pageResults = compositeAgg + .getBuckets() + .stream() + .filter(bucket -> bucket.getDocCount() >= minimumDocCount) + .map(bucket -> Entity.createEntityByReordering(bucket.getKey())) + .collect(Collectors.toList()); + // we only need at most maxEntitiesForPreview + int amountToWrite = maxEntitiesSize - topEntities.size(); + for (int i = 0; i < amountToWrite && i < pageResults.size(); i++) { + topEntities.add(pageResults.get(i)); + } + Map afterKey = compositeAgg.afterKey(); + if (topEntities.size() >= maxEntitiesSize || afterKey == null) { + listener.onResponse(topEntities); + } else if (expirationEpochMs < clock.millis()) { + if (topEntities.isEmpty()) { + listener.onFailure(new TimeSeriesException("timeout to get preview results. Please retry later.")); + } else { + logger.info("timeout to get preview results. 
Send whatever we have."); + listener.onResponse(topEntities); + } + } else { + updateSourceAfterKey(afterKey, searchSourceBuilder); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder), + client::search, + detector.getId(), + client, + this + ); + } + } + } catch (Exception e) { + onFailure(e); + } + } + + @Override + public void onFailure(Exception e) { + logger.error("Fail to paginate", e); + listener.onFailure(e); + } + } + + /** + * Get the entity's earliest timestamps + * @param detector detector config + * @param entity the entity's information + * @param listener listener to return back the requested timestamps + */ + public void getEntityMinDataTime(AnomalyDetector detector, Entity entity, ActionListener> listener) { + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery(); + + for (TermQueryBuilder term : entity.getTermQueryBuilders()) { + internalFilterQuery.filter(term); + } + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(internalFilterQuery) + .aggregation(AggregationBuilders.min(AGG_NAME_MIN).field(detector.getTimeField())) + .trackTotalHits(false) + .size(0); + SearchRequest searchRequest = new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener + .wrap(response -> { listener.onResponse(parseMinDataTime(response)); }, listener::onFailure); + // inject user role while searching. + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + detector.getId(), + client, + searchResponseListener + ); + } + + private Optional parseMinDataTime(SearchResponse searchResponse) { + Optional> mapOptional = Optional + .ofNullable(searchResponse) + .map(SearchResponse::getAggregations) + .map(aggs -> aggs.asMap()); + + return mapOptional.map(map -> (Min) map.get(AGG_NAME_MIN)).map(agg -> (long) agg.getValue()); + } + + /** + * Returns to listener features for the given time period. + * + * @param detector info about indices, feature query + * @param startTime epoch milliseconds at the beginning of the period + * @param endTime epoch milliseconds at the end of the period + * @param listener onResponse is called with features for the given time period. 
+ */ + public void getFeaturesForPeriod(AnomalyDetector detector, long startTime, long endTime, ActionListener> listener) { + SearchRequest searchRequest = createFeatureSearchRequest(detector, startTime, endTime, Optional.empty()); + final ActionListener searchResponseListener = ActionListener + .wrap(response -> listener.onResponse(parseResponse(response, detector.getEnabledFeatureIds())), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + detector.getId(), + client, + searchResponseListener + ); + } + + public void getFeaturesForPeriodByBatch( + AnomalyDetector detector, + Entity entity, + long startTime, + long endTime, + ActionListener>> listener + ) throws IOException { + SearchSourceBuilder searchSourceBuilder = batchFeatureQuery(detector, entity, startTime, endTime, xContent); + logger.debug("Batch query for detector {}: {} ", detector.getId(), searchSourceBuilder); + + SearchRequest searchRequest = new SearchRequest(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener + .wrap( + response -> { listener.onResponse(parseBucketAggregationResponse(response, detector.getEnabledFeatureIds())); }, + listener::onFailure + ); + // inject user role while searching. + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + detector.getId(), + client, + searchResponseListener + ); + } + + private Map> parseBucketAggregationResponse(SearchResponse response, List featureIds) { + Map> dataPoints = new HashMap<>(); + List aggregations = response.getAggregations().asList(); + logger.debug("Feature aggregation result size {}", aggregations.size()); + for (Aggregation agg : aggregations) { + List buckets = ((InternalComposite) agg).getBuckets(); + buckets.forEach(bucket -> { + Optional featureData = parseAggregations(Optional.ofNullable(bucket.getAggregations()), featureIds); + dataPoints.put((Long) bucket.getKey().get(CommonName.DATE_HISTOGRAM), featureData); + }); + } + return dataPoints; + } + + public Optional parseResponse(SearchResponse response, List featureIds) { + return parseAggregations(Optional.ofNullable(response).map(resp -> resp.getAggregations()), featureIds); + } + + /** + * Gets samples of features for the time ranges. + * + * Sampled features are not true features. They are intended to be approximate results produced at low costs. 
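getFeaturesForPeriodByBatch above returns a map keyed by the date-histogram bucket start time, and consumers typically want those buckets in time order. A minimal sketch of such a consumer; the dao, detector, entity, time range, and the Map<Long, Optional<double[]>> listener type (generics are erased in this listing) are assumptions for illustration.

    import java.io.IOException;
    import java.util.Arrays;
    import java.util.Map;
    import java.util.Optional;

    import org.opensearch.action.ActionListener;
    import org.opensearch.ad.model.AnomalyDetector;
    import org.opensearch.timeseries.model.Entity;

    final class BatchFeaturesExample {
        // Hypothetical consumer: print feature vectors in ascending bucket-time order.
        static void printFeatures(SearchFeatureDao dao, AnomalyDetector detector, Entity entity, long start, long end)
            throws IOException {
            dao.getFeaturesForPeriodByBatch(detector, entity, start, end, ActionListener.wrap(
                (Map<Long, Optional<double[]>> byTime) -> byTime.entrySet()
                    .stream()
                    .sorted(Map.Entry.comparingByKey())
                    .forEach(e -> e.getValue().ifPresent(f -> System.out.println(e.getKey() + " -> " + Arrays.toString(f)))),
                e -> System.err.println("batch feature query failed: " + e)
            ));
        }
    }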
+ * + * @param detector info about the indices, documents, feature query + * @param ranges list of time ranges + * @param listener handle approximate features for the time ranges + * @throws IOException if a user gives wrong query input when defining a detector + */ + public void getFeatureSamplesForPeriods( + AnomalyDetector detector, + List> ranges, + ActionListener>> listener + ) throws IOException { + SearchRequest request = createPreviewSearchRequest(detector, ranges); + final ActionListener searchResponseListener = ActionListener.wrap(response -> { + Aggregations aggs = response.getAggregations(); + if (aggs == null) { + listener.onResponse(Collections.emptyList()); + return; + } + + listener + .onResponse( + aggs + .asList() + .stream() + .filter(InternalDateRange.class::isInstance) + .flatMap(agg -> ((InternalDateRange) agg).getBuckets().stream()) + .map(bucket -> parseBucket(bucket, detector.getEnabledFeatureIds())) + .collect(Collectors.toList()) + ); + }, listener::onFailure); + // inject user role while searching + clientUtil + .asyncRequestWithInjectedSecurity( + request, + client::search, + detector.getId(), + client, + searchResponseListener + ); + } + + /** + * Returns to listener features for sampled periods. + * + * Sampling starts with the latest period and goes backwards in time until there are up to {@code maxSamples} samples. + * If the initial stride {@code maxStride} results into a low count of samples, the implementation + * may attempt with (exponentially) reduced strides and interpolate missing points. + * + * @param detector info about indices, documents, feature query + * @param maxSamples the maximum number of samples to return + * @param maxStride the maximum number of periods between samples + * @param endTime the end time of the latest period + * @param listener onResponse is called with sampled features and stride between points, or empty for no data + */ + public void getFeaturesForSampledPeriods( + AnomalyDetector detector, + int maxSamples, + int maxStride, + long endTime, + ActionListener>> listener + ) { + Map cache = new HashMap<>(); + logger.info(String.format(Locale.ROOT, "Getting features for detector %s ending at %d", detector.getId(), endTime)); + getFeatureSamplesWithCache(detector, maxSamples, maxStride, endTime, cache, maxStride, listener); + } + + private void getFeatureSamplesWithCache( + AnomalyDetector detector, + int maxSamples, + int maxStride, + long endTime, + Map cache, + int currentStride, + ActionListener>> listener + ) { + getFeatureSamplesForStride( + detector, + maxSamples, + maxStride, + currentStride, + endTime, + cache, + ActionListener + .wrap( + features -> processFeatureSamplesForStride( + features, + detector, + maxSamples, + maxStride, + currentStride, + endTime, + cache, + listener + ), + listener::onFailure + ) + ); + } + + private void processFeatureSamplesForStride( + Optional features, + AnomalyDetector detector, + int maxSamples, + int maxStride, + int currentStride, + long endTime, + Map cache, + ActionListener>> listener + ) { + if (!features.isPresent()) { + logger + .info( + String + .format( + Locale.ROOT, + "Get features for detector %s finishes without any features present, current stride %d", + detector.getId(), + currentStride + ) + ); + listener.onResponse(Optional.empty()); + } else if (features.get().length > maxSamples / 2 || currentStride == 1) { + logger + .info( + String + .format( + Locale.ROOT, + "Get features for detector %s finishes with %d samples, current stride %d", + detector.getId(), + 
features.get().length, + currentStride + ) + ); + listener.onResponse(Optional.of(new SimpleEntry<>(features.get(), currentStride))); + } else { + getFeatureSamplesWithCache(detector, maxSamples, maxStride, endTime, cache, currentStride / 2, listener); + } + } + + private void getFeatureSamplesForStride( + AnomalyDetector detector, + int maxSamples, + int maxStride, + int currentStride, + long endTime, + Map cache, + ActionListener> listener + ) { + ArrayDeque sampledFeatures = new ArrayDeque<>(maxSamples); + boolean isInterpolatable = currentStride < maxStride; + long span = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().toMillis(); + sampleForIteration(detector, cache, maxSamples, endTime, span, currentStride, sampledFeatures, isInterpolatable, 0, listener); + } + + private void sampleForIteration( + AnomalyDetector detector, + Map cache, + int maxSamples, + long endTime, + long span, + int stride, + ArrayDeque sampledFeatures, + boolean isInterpolatable, + int iteration, + ActionListener> listener + ) { + if (iteration < maxSamples) { + long end = endTime - span * stride * iteration; + if (cache.containsKey(end)) { + sampledFeatures.addFirst(cache.get(end)); + sampleForIteration( + detector, + cache, + maxSamples, + endTime, + span, + stride, + sampledFeatures, + isInterpolatable, + iteration + 1, + listener + ); + } else { + getFeaturesForPeriod(detector, end - span, end, ActionListener.wrap(features -> { + if (features.isPresent()) { + cache.put(end, features.get()); + sampledFeatures.addFirst(features.get()); + sampleForIteration( + detector, + cache, + maxSamples, + endTime, + span, + stride, + sampledFeatures, + isInterpolatable, + iteration + 1, + listener + ); + } else if (isInterpolatable) { + Optional previous = Optional.ofNullable(cache.get(end - span * stride)); + Optional next = Optional.ofNullable(cache.get(end + span * stride)); + if (previous.isPresent() && next.isPresent()) { + double[] interpolants = getInterpolants(previous.get(), next.get()); + cache.put(end, interpolants); + sampledFeatures.addFirst(interpolants); + sampleForIteration( + detector, + cache, + maxSamples, + endTime, + span, + stride, + sampledFeatures, + isInterpolatable, + iteration + 1, + listener + ); + } else { + listener.onResponse(toMatrix(sampledFeatures)); + } + } else { + listener.onResponse(toMatrix(sampledFeatures)); + } + }, listener::onFailure)); + } + } else { + listener.onResponse(toMatrix(sampledFeatures)); + } + } + + private Optional toMatrix(ArrayDeque sampledFeatures) { + Optional samples; + if (sampledFeatures.isEmpty()) { + samples = Optional.empty(); + } else { + samples = Optional.of(sampledFeatures.toArray(new double[0][0])); + } + return samples; + } + + private double[] getInterpolants(double[] previous, double[] next) { + return transpose(imputer.impute(transpose(new double[][] { previous, next }), 3))[1]; + } + + private double[][] transpose(double[][] matrix) { + return createRealMatrix(matrix).transpose().getData(); + } + + private SearchRequest createFeatureSearchRequest(AnomalyDetector detector, long startTime, long endTime, Optional preference) { + // TODO: FeatureQuery field is planned to be removed and search request creation will migrate to new api. 
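getInterpolants above fills a missing sample by transposing the two known samples into per-feature rows, asking the imputer for a three-point series, and taking the middle column. A self-contained illustration of that shape manipulation; a plain linear midpoint stands in for the plugin's Imputer, which is the assumption here.

    import static org.apache.commons.math3.linear.MatrixUtils.createRealMatrix;

    import java.util.Arrays;

    final class InterpolationSketch {
        // Rows are samples, columns are features; transpose so each row becomes one feature's series.
        static double[] midpoint(double[] previous, double[] next) {
            double[][] byFeature = createRealMatrix(new double[][] { previous, next }).transpose().getData();
            double[] interpolated = new double[byFeature.length];
            for (int f = 0; f < byFeature.length; f++) {
                interpolated[f] = (byFeature[f][0] + byFeature[f][1]) / 2;
            }
            return interpolated;
        }

        public static void main(String[] args) {
            // previous = [1, 10], next = [3, 30] -> missing middle sample [2, 20]
            System.out.println(Arrays.toString(midpoint(new double[] { 1, 10 }, new double[] { 3, 30 })));
        }
    }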
+ try { + SearchSourceBuilder searchSourceBuilder = ParseUtils.generateInternalFeatureQuery(detector, startTime, endTime, xContent); + return new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder).preference(preference.orElse(null)); + } catch (IOException e) { + logger.warn("Failed to create feature search request for " + detector.getId() + " from " + startTime + " to " + endTime, e); + throw new IllegalStateException(e); + } + } + + private SearchRequest createPreviewSearchRequest(AnomalyDetector detector, List> ranges) throws IOException { + try { + SearchSourceBuilder searchSourceBuilder = ParseUtils.generatePreviewQuery(detector, ranges, xContent); + return new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder); + } catch (IOException e) { + logger.warn("Failed to create feature search request for " + detector.getId() + " for preview", e); + throw e; + } + } + + public void getColdStartSamplesForPeriods( + AnomalyDetector detector, + List> ranges, + Entity entity, + boolean includesEmptyBucket, + ActionListener>> listener + ) { + SearchRequest request = createColdStartFeatureSearchRequest(detector, ranges, entity); + final ActionListener searchResponseListener = ActionListener.wrap(response -> { + Aggregations aggs = response.getAggregations(); + if (aggs == null) { + listener.onResponse(Collections.emptyList()); + return; + } + + long docCountThreshold = includesEmptyBucket ? -1 : 0; + + // Extract buckets and order by from_as_string. Currently by default it is ascending. Better not to assume it. + // Example responses from date range bucket aggregation: + // "aggregations":{"date_range":{"buckets":[{"key":"1598865166000-1598865226000","from":1.598865166E12," + // from_as_string":"1598865166000","to":1.598865226E12,"to_as_string":"1598865226000","doc_count":3, + // "deny_max":{"value":154.0}},{"key":"1598869006000-1598869066000","from":1.598869006E12, + // "from_as_string":"1598869006000","to":1.598869066E12,"to_as_string":"1598869066000","doc_count":3, + // "deny_max":{"value":141.0}}, + // We don't want to use default 0 for sum/count aggregation as it might cause false positives during scoring. + // Terms aggregation only returns non-zero count values. If we use a lot of 0s during cold start, + // we will see alarming very easily. + listener + .onResponse( + aggs + .asList() + .stream() + .filter(InternalDateRange.class::isInstance) + .flatMap(agg -> ((InternalDateRange) agg).getBuckets().stream()) + .filter(bucket -> bucket.getFrom() != null && bucket.getFrom() instanceof ZonedDateTime) + .filter(bucket -> bucket.getDocCount() > docCountThreshold) + .sorted(Comparator.comparing((Bucket bucket) -> (ZonedDateTime) bucket.getFrom())) + .map(bucket -> parseBucket(bucket, detector.getEnabledFeatureIds())) + .collect(Collectors.toList()) + ); + }, listener::onFailure); + + // inject user role while searching. 
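getColdStartSamplesForPeriods shown here expects the caller to supply explicit time ranges and returns the matching date-range buckets oldest first. A sketch of building consecutive interval-aligned ranges and invoking the method; the dao, detector, entity, sample count, end time, and the List<Optional<double[]>> listener type (generics are erased in this listing) are assumptions.

    import java.util.AbstractMap.SimpleEntry;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.Map.Entry;
    import java.util.Optional;

    import org.opensearch.action.ActionListener;
    import org.opensearch.ad.model.AnomalyDetector;
    import org.opensearch.timeseries.model.Entity;
    import org.opensearch.timeseries.model.IntervalTimeConfiguration;

    final class ColdStartSamplesExample {
        // Hypothetical caller: fetch numSamples consecutive interval buckets ending at endMillis.
        static void fetch(SearchFeatureDao dao, AnomalyDetector detector, Entity entity, int numSamples, long endMillis) {
            long interval = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().toMillis();
            List<Entry<Long, Long>> ranges = new ArrayList<>();
            for (int i = numSamples - 1; i >= 0; i--) {
                long rangeEnd = endMillis - i * interval;
                ranges.add(new SimpleEntry<>(rangeEnd - interval, rangeEnd));
            }
            dao.getColdStartSamplesForPeriods(detector, ranges, entity, false, ActionListener.wrap(
                (List<Optional<double[]>> samples) -> System.out.println("received " + samples.size() + " buckets"),
                e -> System.err.println("cold start query failed: " + e)
            ));
        }
    }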
+ clientUtil + .asyncRequestWithInjectedSecurity( + request, + client::search, + detector.getId(), + client, + searchResponseListener + ); + } + + private SearchRequest createColdStartFeatureSearchRequest(AnomalyDetector detector, List> ranges, Entity entity) { + try { + SearchSourceBuilder searchSourceBuilder = ParseUtils.generateEntityColdStartQuery(detector, ranges, entity, xContent); + return new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder); + } catch (IOException e) { + logger + .warn( + "Failed to create cold start feature search request for " + + detector.getId() + + " from " + + ranges.get(0).getKey() + + " to " + + ranges.get(ranges.size() - 1).getKey(), + e + ); + throw new IllegalStateException(e); + } + } + + @Override + public Optional parseBucket(MultiBucketsAggregation.Bucket bucket, List featureIds) { + return parseAggregations(Optional.ofNullable(bucket).map(b -> b.getAggregations()), featureIds); + } +} diff --git a/src/main/java/org/opensearch/ad/feature/SinglePointFeatures.java-e b/src/main/java/org/opensearch/ad/feature/SinglePointFeatures.java-e new file mode 100644 index 000000000..cbd7ef78b --- /dev/null +++ b/src/main/java/org/opensearch/ad/feature/SinglePointFeatures.java-e @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.feature; + +import java.util.Optional; + +/** + * Features for one data point. + * + * A data point consists of unprocessed features (raw search results) and corresponding processed ML features. + */ +public class SinglePointFeatures { + + private final Optional unprocessedFeatures; + private final Optional processedFeatures; + + /** + * Constructor. + * + * @param unprocessedFeatures unprocessed features + * @param processedFeatures processed features + */ + public SinglePointFeatures(Optional unprocessedFeatures, Optional processedFeatures) { + this.unprocessedFeatures = unprocessedFeatures; + this.processedFeatures = processedFeatures; + } + + /** + * Returns unprocessed features. + * + * @return unprocessed features + */ + public Optional getUnprocessedFeatures() { + return this.unprocessedFeatures; + } + + /** + * Returns processed features. + * + * @return processed features + */ + public Optional getProcessedFeatures() { + return this.processedFeatures; + } +} diff --git a/src/main/java/org/opensearch/ad/indices/ADIndex.java-e b/src/main/java/org/opensearch/ad/indices/ADIndex.java-e new file mode 100644 index 000000000..b345ef33e --- /dev/null +++ b/src/main/java/org/opensearch/ad/indices/ADIndex.java-e @@ -0,0 +1,73 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.indices; + +import java.util.function.Supplier; + +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ThrowingSupplierWrapper; +import org.opensearch.timeseries.indices.TimeSeriesIndex; + +/** + * Represent an AD index + * + */ +public enum ADIndex implements TimeSeriesIndex { + + // throw RuntimeException since we don't know how to handle the case when the mapping reading throws IOException + RESULT( + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + true, + ThrowingSupplierWrapper.throwingSupplierWrapper(ADIndexManagement::getResultMappings) + ), + CONFIG(CommonName.CONFIG_INDEX, false, ThrowingSupplierWrapper.throwingSupplierWrapper(ADIndexManagement::getConfigMappings)), + JOB(CommonName.JOB_INDEX, false, ThrowingSupplierWrapper.throwingSupplierWrapper(ADIndexManagement::getJobMappings)), + CHECKPOINT( + ADCommonName.CHECKPOINT_INDEX_NAME, + false, + ThrowingSupplierWrapper.throwingSupplierWrapper(ADIndexManagement::getCheckpointMappings) + ), + STATE(ADCommonName.DETECTION_STATE_INDEX, false, ThrowingSupplierWrapper.throwingSupplierWrapper(ADIndexManagement::getStateMappings)); + + private final String indexName; + // whether we use an alias for the index + private final boolean alias; + private final String mapping; + + ADIndex(String name, boolean alias, Supplier mappingSupplier) { + this.indexName = name; + this.alias = alias; + this.mapping = mappingSupplier.get(); + } + + @Override + public String getIndexName() { + return indexName; + } + + @Override + public boolean isAlias() { + return alias; + } + + @Override + public String getMapping() { + return mapping; + } + + @Override + public boolean isJobIndex() { + return CommonName.JOB_INDEX.equals(indexName); + } + +} diff --git a/src/main/java/org/opensearch/ad/indices/ADIndexManagement.java-e b/src/main/java/org/opensearch/ad/indices/ADIndexManagement.java-e new file mode 100644 index 000000000..ef534621d --- /dev/null +++ b/src/main/java/org/opensearch/ad/indices/ADIndexManagement.java-e @@ -0,0 +1,284 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
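The ADIndex enum above resolves each mapping eagerly through ThrowingSupplierWrapper so that an IOException from reading a mapping file surfaces as a RuntimeException inside the enum constructor. A minimal sketch of what such a wrapper typically looks like; the interface and method names here are assumptions rather than the plugin's actual signatures.

    import java.util.function.Supplier;

    final class ThrowingSupplierWrapperSketch {
        @FunctionalInterface
        interface ThrowingSupplier<T> {
            T get() throws Exception;
        }

        // Wrap a supplier that may throw a checked exception into a plain Supplier.
        static <T> Supplier<T> wrap(ThrowingSupplier<T> supplier) {
            return () -> {
                try {
                    return supplier.get();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            };
        }
    }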
+ */ + +package org.opensearch.ad.indices; + +import static org.opensearch.ad.constant.ADCommonName.DUMMY_AD_RESULT_ID; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_PRIMARY_SHARDS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.ANOMALY_DETECTION_STATE_INDEX_MAPPING_FILE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.ANOMALY_RESULTS_INDEX_MAPPING_FILE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.CHECKPOINT_INDEX_MAPPING_FILE; +import static org.opensearch.cluster.metadata.IndexMetadata.SETTING_REPLICATION_TYPE; +import static org.opensearch.indices.replication.common.ReplicationType.DOCUMENT; + +import java.io.IOException; +import java.util.EnumMap; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +/** + * This class provides utility methods for various anomaly detection indices. 
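The history index pattern and all-results wildcard declared just below rely on OpenSearch date-math index names: when the pattern is used to create the first result index, the cluster resolves {now/d} to the current date and the write alias is attached to the concrete index, which the wildcard pattern then matches. A rough sketch of what initDefaultResultIndexDirectly ultimately delegates to; the injected Client, the listener handling, and the exact resolved index name are assumptions.

    import org.opensearch.action.ActionListener;
    import org.opensearch.action.admin.indices.alias.Alias;
    import org.opensearch.action.admin.indices.create.CreateIndexRequest;
    import org.opensearch.ad.constant.ADCommonName;
    import org.opensearch.client.Client;

    final class HistoryIndexSketch {
        // Hypothetical direct creation of the first dated history index with the result alias attached.
        static void createFirstHistoryIndex(Client client) {
            CreateIndexRequest request = new CreateIndexRequest(ADIndexManagement.AD_RESULT_HISTORY_INDEX_PATTERN)
                .alias(new Alias(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS));
            client.admin().indices().create(request, ActionListener.wrap(
                response -> System.out.println("history index acknowledged: " + response.isAcknowledged()),
                e -> System.err.println("failed to create history index: " + e)));
        }
    }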
+ */ +public class ADIndexManagement extends IndexManagement { + private static final Logger logger = LogManager.getLogger(ADIndexManagement.class); + + // The index name pattern to query all the AD result history indices + public static final String AD_RESULT_HISTORY_INDEX_PATTERN = "<.opendistro-anomaly-results-history-{now/d}-1>"; + + // The index name pattern to query all AD result, history and current AD result + public static final String ALL_AD_RESULTS_INDEX_PATTERN = ".opendistro-anomaly-results*"; + + /** + * Constructor function + * + * @param client OS client supports administrative actions + * @param clusterService OS cluster service + * @param threadPool OS thread pool + * @param settings OS cluster setting + * @param nodeFilter Used to filter eligible nodes to host AD indices + * @param maxUpdateRunningTimes max number of retries to update index mapping and setting + * @throws IOException when failing to get mapping file + */ + public ADIndexManagement( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + Settings settings, + DiscoveryNodeFilterer nodeFilter, + int maxUpdateRunningTimes + ) + throws IOException { + super( + client, + clusterService, + threadPool, + settings, + nodeFilter, + maxUpdateRunningTimes, + ADIndex.class, + AD_MAX_PRIMARY_SHARDS.get(settings), + AD_RESULT_HISTORY_ROLLOVER_PERIOD.get(settings), + AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD.get(settings), + AD_RESULT_HISTORY_RETENTION_PERIOD.get(settings), + ADIndex.RESULT.getMapping() + ); + this.clusterService.addLocalNodeClusterManagerListener(this); + + this.indexStates = new EnumMap(ADIndex.class); + + this.clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD, it -> historyMaxDocs = it); + + this.clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_RESULT_HISTORY_ROLLOVER_PERIOD, it -> { + historyRolloverPeriod = it; + rescheduleRollover(); + }); + this.clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(AD_RESULT_HISTORY_RETENTION_PERIOD, it -> { historyRetentionPeriod = it; }); + + this.clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_MAX_PRIMARY_SHARDS, it -> maxPrimaryShards = it); + } + + /** + * Get anomaly result index mapping json content. + * + * @return anomaly result index mapping + * @throws IOException IOException if mapping file can't be read correctly + */ + public static String getResultMappings() throws IOException { + return getMappings(ANOMALY_RESULTS_INDEX_MAPPING_FILE); + } + + /** + * Get anomaly detector state index mapping json content. + * + * @return anomaly detector state index mapping + * @throws IOException IOException if mapping file can't be read correctly + */ + public static String getStateMappings() throws IOException { + String detectionStateMappings = getMappings(ANOMALY_DETECTION_STATE_INDEX_MAPPING_FILE); + String detectorIndexMappings = getConfigMappings(); + detectorIndexMappings = detectorIndexMappings + .substring(detectorIndexMappings.indexOf("\"properties\""), detectorIndexMappings.lastIndexOf("}")); + return detectionStateMappings.replace("DETECTOR_INDEX_MAPPING_PLACE_HOLDER", detectorIndexMappings); + } + + /** + * Get checkpoint index mapping json content. + * + * @return checkpoint index mapping + * @throws IOException IOException if mapping file can't be read correctly + */ + public static String getCheckpointMappings() throws IOException { + return getMappings(CHECKPOINT_INDEX_MAPPING_FILE); + } + + /** + * anomaly result index exist or not. 
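getStateMappings above splices the detector config mapping into the state index mapping by cutting everything from the first "properties" key to the last closing brace and substituting it for a placeholder token. A toy illustration of that string surgery with made-up mappings (the JSON strings here are not the plugin's real mapping files):

    final class MappingSpliceSketch {
        public static void main(String[] args) {
            // Made-up mappings, only to show the substring/replace mechanics.
            String detectorMappings = "{\"properties\":{\"name\":{\"type\":\"keyword\"}}}";
            String stateMappings = "{\"properties\":{\"detector\":{DETECTOR_INDEX_MAPPING_PLACE_HOLDER}}}";

            String properties = detectorMappings
                .substring(detectorMappings.indexOf("\"properties\""), detectorMappings.lastIndexOf("}"));
            // prints {"properties":{"detector":{"properties":{"name":{"type":"keyword"}}}}}
            System.out.println(stateMappings.replace("DETECTOR_INDEX_MAPPING_PLACE_HOLDER", properties));
        }
    }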
+ * + * @return true if anomaly result index exists + */ + @Override + public boolean doesDefaultResultIndexExist() { + return doesAliasExist(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); + } + + /** + * Anomaly state index exist or not. + * + * @return true if anomaly state index exists + */ + @Override + public boolean doesStateIndexExist() { + return doesIndexExist(ADCommonName.DETECTION_STATE_INDEX); + } + + /** + * Checkpoint index exist or not. + * + * @return true if checkpoint index exists + */ + @Override + public boolean doesCheckpointIndexExist() { + return doesIndexExist(ADCommonName.CHECKPOINT_INDEX_NAME); + } + + /** + * Create anomaly result index without checking exist or not. + * + * @param actionListener action called after create index + */ + @Override + public void initDefaultResultIndexDirectly(ActionListener actionListener) { + initResultIndexDirectly( + AD_RESULT_HISTORY_INDEX_PATTERN, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + true, + AD_RESULT_HISTORY_INDEX_PATTERN, + ADIndex.RESULT, + actionListener + ); + } + + /** + * Create the state index. + * + * @param actionListener action called after create index + */ + @Override + public void initStateIndex(ActionListener actionListener) { + try { + // AD indices need RAW (e.g., we want users to be able to consume AD results as soon as possible and send out an alert if + // anomalies found). + Settings replicationSettings = Settings.builder().put(SETTING_REPLICATION_TYPE, DOCUMENT.name()).build(); + CreateIndexRequest request = new CreateIndexRequest(ADCommonName.DETECTION_STATE_INDEX, replicationSettings) + .mapping(getStateMappings(), XContentType.JSON) + .settings(settings); + adminClient.indices().create(request, markMappingUpToDate(ADIndex.STATE, actionListener)); + } catch (IOException e) { + logger.error("Fail to init AD detection state index", e); + actionListener.onFailure(e); + } + } + + /** + * Create the checkpoint index. + * + * @param actionListener action called after create index + * @throws EndRunException EndRunException due to failure to get mapping + */ + @Override + public void initCheckpointIndex(ActionListener actionListener) { + String mapping; + try { + mapping = getCheckpointMappings(); + } catch (IOException e) { + throw new EndRunException("", "Cannot find checkpoint mapping file", true); + } + // AD indices need RAW (e.g., we want users to be able to consume AD results as soon as possible and send out an alert if anomalies + // found). + Settings replicationSettings = Settings.builder().put(SETTING_REPLICATION_TYPE, DOCUMENT.name()).build(); + CreateIndexRequest request = new CreateIndexRequest(ADCommonName.CHECKPOINT_INDEX_NAME, replicationSettings) + .mapping(mapping, XContentType.JSON); + choosePrimaryShards(request, true); + adminClient.indices().create(request, markMappingUpToDate(ADIndex.CHECKPOINT, actionListener)); + } + + @Override + protected void rolloverAndDeleteHistoryIndex() { + rolloverAndDeleteHistoryIndex( + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + ALL_AD_RESULTS_INDEX_PATTERN, + AD_RESULT_HISTORY_INDEX_PATTERN, + ADIndex.RESULT + ); + } + + /** + * Create config index directly. + * + * @param actionListener action called after create index + * @throws IOException IOException from {@link IndexManagement#getConfigMappings} + */ + @Override + public void initConfigIndex(ActionListener actionListener) throws IOException { + super.initConfigIndex(markMappingUpToDate(ADIndex.CONFIG, actionListener)); + } + + /** + * Create job index. 
+ * + * @param actionListener action called after create index + */ + @Override + public void initJobIndex(ActionListener actionListener) { + super.initJobIndex(markMappingUpToDate(ADIndex.JOB, actionListener)); + } + + @Override + protected IndexRequest createDummyIndexRequest(String resultIndex) throws IOException { + AnomalyResult dummyResult = AnomalyResult.getDummyResult(); + return new IndexRequest(resultIndex) + .id(DUMMY_AD_RESULT_ID) + .source(dummyResult.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); + } + + @Override + protected DeleteRequest createDummyDeleteRequest(String resultIndex) throws IOException { + return new DeleteRequest(resultIndex).id(DUMMY_AD_RESULT_ID); + } + + @Override + public void initCustomResultIndexDirectly(String resultIndex, ActionListener actionListener) { + initResultIndexDirectly(resultIndex, null, false, AD_RESULT_HISTORY_INDEX_PATTERN, ADIndex.RESULT, actionListener); + } +} diff --git a/src/main/java/org/opensearch/ad/ml/CheckpointDao.java-e b/src/main/java/org/opensearch/ad/ml/CheckpointDao.java-e new file mode 100644 index 000000000..738acd197 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ml/CheckpointDao.java-e @@ -0,0 +1,790 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Base64; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; + +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.get.MultiGetAction; +import org.opensearch.action.get.MultiGetRequest; +import org.opensearch.action.get.MultiGetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.client.Client; +import 
org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.index.reindex.ScrollableHitSource; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; + +import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.config.Precision; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; +import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper; +import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState; +import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV3StateConverter; +import com.amazon.randomcutforest.state.RandomCutForestMapper; +import com.amazon.randomcutforest.state.RandomCutForestState; +import com.google.gson.Gson; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; + +import io.protostuff.LinkedBuffer; +import io.protostuff.ProtostuffIOUtil; +import io.protostuff.Schema; + +/** + * DAO for model checkpoints. + */ +public class CheckpointDao { + + private static final Logger logger = LogManager.getLogger(CheckpointDao.class); + static final String TIMEOUT_LOG_MSG = "Timeout while deleting checkpoints of"; + static final String BULK_FAILURE_LOG_MSG = "Bulk failure while deleting checkpoints of"; + static final String SEARCH_FAILURE_LOG_MSG = "Search failure while deleting checkpoints of"; + static final String DOC_GOT_DELETED_LOG_MSG = "checkpoints docs get deleted"; + static final String INDEX_DELETED_LOG_MSG = "Checkpoint index has been deleted. Has nothing to do:"; + static final String NOT_ABLE_TO_DELETE_LOG_MSG = "Cannot delete all checkpoints of detector"; + + public static final String ENTITY_RCF = "rcf"; + public static final String ENTITY_THRESHOLD = "th"; + public static final String ENTITY_TRCF = "trcf"; + public static final String FIELD_MODELV2 = "modelV2"; + public static final String DETECTOR_ID = "detectorId"; + + // dependencies + private final Client client; + private final ClientUtil clientUtil; + + // configuration + private final String indexName; + + private Gson gson; + private RandomCutForestMapper mapper; + + // For further reference v1, v2 and v3 refer to the different variations of RCF models + // used by AD. v1 was originally used with the launch of OS 1.0. We later converted to v2 + // which included changes requiring a specific converter from v1 to v2 for BWC. + // v2 models are created by RCF-3.0-rc1 which can be found on maven central. + // v3 is the latest model version form RCF introduced by RCF-3.0-rc2. + // Although this version has a converter method for v2 to v3, after BWC testing it was decided that + // an explicit use of the converter won't be needed as the changes between the models are indeed BWC. 
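The serializeRCFBufferPool field declared below hands out reusable protostuff LinkedBuffers so checkpoint serialization does not allocate a fresh buffer for every model write. A minimal sketch of how such a pool can be constructed and used; the buffer size and pool configuration are assumptions, not values taken from this change.

    import org.apache.commons.pool2.BasePooledObjectFactory;
    import org.apache.commons.pool2.PooledObject;
    import org.apache.commons.pool2.impl.DefaultPooledObject;
    import org.apache.commons.pool2.impl.GenericObjectPool;

    import io.protostuff.LinkedBuffer;

    final class BufferPoolSketch {
        static GenericObjectPool<LinkedBuffer> newPool(int bufferBytes) {
            return new GenericObjectPool<>(new BasePooledObjectFactory<LinkedBuffer>() {
                @Override
                public LinkedBuffer create() {
                    return LinkedBuffer.allocate(bufferBytes);
                }

                @Override
                public PooledObject<LinkedBuffer> wrap(LinkedBuffer buffer) {
                    return new DefaultPooledObject<>(buffer);
                }
            });
        }

        static void use(GenericObjectPool<LinkedBuffer> pool) throws Exception {
            LinkedBuffer buffer = pool.borrowObject();
            try {
                // serialize into the buffer here; on a serialization failure the DAO below
                // invalidates the buffer instead of returning it to the pool
            } finally {
                pool.returnObject(buffer);
            }
        }
    }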
+ private V1JsonToV3StateConverter converter; + private ThresholdedRandomCutForestMapper trcfMapper; + private Schema trcfSchema; + + private final Class thresholdingModelClass; + + private final ADIndexManagement indexUtil; + private final JsonParser parser = new JsonParser(); + // we won't read/write a checkpoint larger than a threshold + private final int maxCheckpointBytes; + + private final GenericObjectPool serializeRCFBufferPool; + private final int serializeRCFBufferSize; + // anomaly rate + private double anomalyRate; + + /** + * Constructor with dependencies and configuration. + * + * @param client ES search client + * @param clientUtil utility with ES client + * @param indexName name of the index for model checkpoints + * @param gson accessor to Gson functionality + * @param mapper RCF model serialization utility + * @param converter converter from rcf v1 serde to protostuff based format + * @param trcfMapper TRCF serialization mapper + * @param trcfSchema TRCF serialization schema + * @param thresholdingModelClass thresholding model's class + * @param indexUtil Index utility methods + * @param maxCheckpointBytes max checkpoint size in bytes + * @param serializeRCFBufferPool object pool for serializing rcf models + * @param serializeRCFBufferSize the size of the buffer for RCF serialization + * @param anomalyRate anomaly rate + */ + public CheckpointDao( + Client client, + ClientUtil clientUtil, + String indexName, + Gson gson, + RandomCutForestMapper mapper, + V1JsonToV3StateConverter converter, + ThresholdedRandomCutForestMapper trcfMapper, + Schema trcfSchema, + Class thresholdingModelClass, + ADIndexManagement indexUtil, + int maxCheckpointBytes, + GenericObjectPool serializeRCFBufferPool, + int serializeRCFBufferSize, + double anomalyRate + ) { + this.client = client; + this.clientUtil = clientUtil; + this.indexName = indexName; + this.gson = gson; + this.mapper = mapper; + this.converter = converter; + this.trcfMapper = trcfMapper; + this.trcfSchema = trcfSchema; + this.thresholdingModelClass = thresholdingModelClass; + this.indexUtil = indexUtil; + this.maxCheckpointBytes = maxCheckpointBytes; + this.serializeRCFBufferPool = serializeRCFBufferPool; + this.serializeRCFBufferSize = serializeRCFBufferSize; + this.anomalyRate = anomalyRate; + } + + private void saveModelCheckpointSync(Map source, String modelId) { + clientUtil.timedRequest(new IndexRequest(indexName).id(modelId).source(source), logger, client::index); + } + + private void putModelCheckpoint(String modelId, Map source, ActionListener listener) { + if (indexUtil.doesCheckpointIndexExist()) { + saveModelCheckpointAsync(source, modelId, listener); + } else { + onCheckpointNotExist(source, modelId, true, listener); + } + } + + /** + * Puts a rcf model checkpoint in the storage. + * + * @param modelId id of the model + * @param forest the rcf model + * @param listener onResponse is called with null when the operation is completed + */ + public void putTRCFCheckpoint(String modelId, ThresholdedRandomCutForest forest, ActionListener listener) { + Map source = new HashMap<>(); + String modelCheckpoint = toCheckpoint(forest); + if (modelCheckpoint != null) { + source.put(FIELD_MODELV2, modelCheckpoint); + source.put(CommonName.TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); + putModelCheckpoint(modelId, source, listener); + } else { + listener.onFailure(new RuntimeException("Fail to create checkpoint to save")); + } + } + + /** + * Puts a thresholding model checkpoint in the storage. 
+ * + * @param modelId id of the model + * @param threshold the thresholding model + * @param listener onResponse is called with null when the operation is completed + */ + public void putThresholdCheckpoint(String modelId, ThresholdingModel threshold, ActionListener listener) { + String modelCheckpoint = AccessController.doPrivileged((PrivilegedAction) () -> gson.toJson(threshold)); + Map source = new HashMap<>(); + source.put(CommonName.FIELD_MODEL, modelCheckpoint); + source.put(CommonName.TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); + putModelCheckpoint(modelId, source, listener); + } + + private void onCheckpointNotExist(Map source, String modelId, boolean isAsync, ActionListener listener) { + indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { + if (initResponse.isAcknowledged()) { + if (isAsync) { + saveModelCheckpointAsync(source, modelId, listener); + } else { + saveModelCheckpointSync(source, modelId); + } + } else { + throw new RuntimeException("Creating checkpoint with mappings call not acknowledged."); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + if (isAsync) { + saveModelCheckpointAsync(source, modelId, listener); + } else { + saveModelCheckpointSync(source, modelId); + } + } else { + logger.error(String.format(Locale.ROOT, "Unexpected error creating index %s", indexName), exception); + } + })); + } + + /** + * Update the model doc using fields in source. This ensures we won't touch + * the old checkpoint and nodes with old/new logic can coexist in a cluster. + * This is useful for introducing compact rcf new model format. + * + * @param source fields to update + * @param modelId model Id, used as doc id in the checkpoint index + * @param listener Listener to return response + */ + private void saveModelCheckpointAsync(Map source, String modelId, ActionListener listener) { + + UpdateRequest updateRequest = new UpdateRequest(indexName, modelId); + updateRequest.doc(source); + // If the document does not already exist, the contents of the upsert element are inserted as a new document. + // If the document exists, update fields in the map + updateRequest.docAsUpsert(true); + clientUtil + .asyncRequest( + updateRequest, + client::update, + ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) + ); + } + + /** + * Prepare for index request using the contents of the given model state + * @param modelState an entity model state + * @return serialized JSON map or empty map if the state is too bloated + * @throws IOException when serialization fails + */ + public Map toIndexSource(ModelState modelState) throws IOException { + String modelId = modelState.getModelId(); + Map source = new HashMap<>(); + EntityModel model = modelState.getModel(); + Optional serializedModel = toCheckpoint(model, modelId); + if (!serializedModel.isPresent() || serializedModel.get().length() > maxCheckpointBytes) { + logger + .warn( + new ParameterizedMessage( + "[{}]'s model is empty or too large: [{}] bytes", + modelState.getModelId(), + serializedModel.isPresent() ? 
serializedModel.get().length() : 0 + ) + ); + return source; + } + String detectorId = modelState.getId(); + source.put(DETECTOR_ID, detectorId); + // we cannot pass Optional as OpenSearch does not know how to serialize an Optional value + source.put(FIELD_MODELV2, serializedModel.get()); + source.put(CommonName.TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); + source.put(CommonName.SCHEMA_VERSION_FIELD, indexUtil.getSchemaVersion(ADIndex.CHECKPOINT)); + Optional entity = model.getEntity(); + if (entity.isPresent()) { + source.put(CommonName.ENTITY_KEY, entity.get()); + } + + return source; + } + + /** + * Serialized an EntityModel + * @param model input model + * @param modelId model id + * @return serialized string + */ + public Optional toCheckpoint(EntityModel model, String modelId) { + return AccessController.doPrivileged((PrivilegedAction>) () -> { + if (model == null) { + logger.warn("Empty model"); + return Optional.empty(); + } + try { + JsonObject json = new JsonObject(); + if (model.getSamples() != null && !(model.getSamples().isEmpty())) { + json.add(CommonName.ENTITY_SAMPLE, gson.toJsonTree(model.getSamples())); + } + if (model.getTrcf().isPresent()) { + json.addProperty(ENTITY_TRCF, toCheckpoint(model.getTrcf().get())); + } + // if json is empty, it will be an empty Json string {}. No need to save it on disk. + return json.entrySet().isEmpty() ? Optional.empty() : Optional.ofNullable(gson.toJson(json)); + } catch (Exception ex) { + logger.warn(new ParameterizedMessage("fail to generate checkpoint for [{}]", modelId), ex); + } + return Optional.empty(); + }); + } + + private String toCheckpoint(ThresholdedRandomCutForest trcf) { + String checkpoint = null; + Map.Entry result = checkoutOrNewBuffer(); + LinkedBuffer buffer = result.getKey(); + boolean needCheckin = result.getValue(); + try { + checkpoint = toCheckpoint(trcf, buffer); + } catch (Exception e) { + logger.error("Failed to serialize model", e); + if (needCheckin) { + try { + serializeRCFBufferPool.invalidateObject(buffer); + needCheckin = false; + } catch (Exception x) { + logger.warn("Failed to invalidate buffer", x); + } + try { + checkpoint = toCheckpoint(trcf, LinkedBuffer.allocate(serializeRCFBufferSize)); + } catch (Exception ex) { + logger.warn("Failed to generate checkpoint", ex); + } + } + } finally { + if (needCheckin) { + try { + serializeRCFBufferPool.returnObject(buffer); + } catch (Exception e) { + logger.warn("Failed to return buffer to pool", e); + } + } + } + return checkpoint; + } + + private Map.Entry checkoutOrNewBuffer() { + LinkedBuffer buffer = null; + boolean isCheckout = true; + try { + buffer = serializeRCFBufferPool.borrowObject(); + } catch (Exception e) { + logger.warn("Failed to borrow a buffer from pool", e); + } + if (buffer == null) { + buffer = LinkedBuffer.allocate(serializeRCFBufferSize); + isCheckout = false; + } + return new SimpleImmutableEntry(buffer, isCheckout); + } + + private String toCheckpoint(ThresholdedRandomCutForest trcf, LinkedBuffer buffer) { + try { + byte[] bytes = AccessController.doPrivileged((PrivilegedAction) () -> { + ThresholdedRandomCutForestState trcfState = trcfMapper.toState(trcf); + return ProtostuffIOUtil.toByteArray(trcfState, trcfSchema, buffer); + }); + return Base64.getEncoder().encodeToString(bytes); + } finally { + buffer.clear(); + } + } + + /** + * Deletes the model checkpoint for the model. 
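toCheckpoint above turns a ThresholdedRandomCutForest into a Base64 string via its protostuff state, and the deserialization paths later in this class reverse the process. A condensed round-trip sketch without the buffer pool, privileged blocks, or mapper configuration flags; obtaining the schema through RuntimeSchema and the 512-byte buffer are assumptions for illustration.

    import java.util.Base64;

    import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
    import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper;
    import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState;

    import io.protostuff.LinkedBuffer;
    import io.protostuff.ProtostuffIOUtil;
    import io.protostuff.Schema;
    import io.protostuff.runtime.RuntimeSchema;

    final class TrcfCheckpointRoundTrip {
        private static final ThresholdedRandomCutForestMapper MAPPER = new ThresholdedRandomCutForestMapper();
        private static final Schema<ThresholdedRandomCutForestState> SCHEMA =
            RuntimeSchema.getSchema(ThresholdedRandomCutForestState.class);

        // Serialize the forest's state to a Base64 checkpoint string.
        static String encode(ThresholdedRandomCutForest forest) {
            ThresholdedRandomCutForestState state = MAPPER.toState(forest);
            byte[] bytes = ProtostuffIOUtil.toByteArray(state, SCHEMA, LinkedBuffer.allocate(512));
            return Base64.getEncoder().encodeToString(bytes);
        }

        // Rebuild the forest from a checkpoint string produced by encode().
        static ThresholdedRandomCutForest decode(String checkpoint) {
            ThresholdedRandomCutForestState state = SCHEMA.newMessage();
            ProtostuffIOUtil.mergeFrom(Base64.getDecoder().decode(checkpoint), state, SCHEMA);
            return MAPPER.toModel(state);
        }
    }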
+ * + * @param modelId id of the model + * @param listener onReponse is called with null when the operation is completed + */ + public void deleteModelCheckpoint(String modelId, ActionListener listener) { + clientUtil + .asyncRequest( + new DeleteRequest(indexName, modelId), + client::delete, + ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) + ); + } + + /** + * Delete checkpoints associated with a detector. Used in multi-entity detector. + * @param detectorID Detector Id + */ + public void deleteModelCheckpointByDetectorId(String detectorID) { + // A bulk delete request is performed for each batch of matching documents. If a + // search or bulk request is rejected, the requests are retried up to 10 times, + // with exponential back off. If the maximum retry limit is reached, processing + // halts and all failed requests are returned in the response. Any delete + // requests that completed successfully still stick, they are not rolled back. + DeleteByQueryRequest deleteRequest = new DeleteByQueryRequest(ADCommonName.CHECKPOINT_INDEX_NAME) + .setQuery(new MatchQueryBuilder(DETECTOR_ID, detectorID)) + .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) + .setAbortOnVersionConflict(false) // when current delete happens, previous might not finish. + // Retry in this case + .setRequestsPerSecond(500); // throttle delete requests + logger.info("Delete checkpoints of detector {}", detectorID); + client.execute(DeleteByQueryAction.INSTANCE, deleteRequest, ActionListener.wrap(response -> { + if (response.isTimedOut() || !response.getBulkFailures().isEmpty() || !response.getSearchFailures().isEmpty()) { + logFailure(response, detectorID); + } + // can return 0 docs get deleted because: + // 1) we cannot find matching docs + // 2) bad stats from OpenSearch. In this case, docs are deleted, but + // OpenSearch says deleted is 0. + logger.info("{} " + DOC_GOT_DELETED_LOG_MSG, response.getDeleted()); + }, exception -> { + if (exception instanceof IndexNotFoundException) { + logger.info(INDEX_DELETED_LOG_MSG + " {}", detectorID); + } else { + // Gonna eventually delete in daily cron. 
+ logger.error(NOT_ABLE_TO_DELETE_LOG_MSG, exception); + } + })); + } + + private void logFailure(BulkByScrollResponse response, String detectorID) { + if (response.isTimedOut()) { + logger.warn(TIMEOUT_LOG_MSG + " {}", detectorID); + } else if (!response.getBulkFailures().isEmpty()) { + logger.warn(BULK_FAILURE_LOG_MSG + " {}", detectorID); + for (BulkItemResponse.Failure bulkFailure : response.getBulkFailures()) { + logger.warn(bulkFailure); + } + } else { + logger.warn(SEARCH_FAILURE_LOG_MSG + " {}", detectorID); + for (ScrollableHitSource.SearchFailure searchFailure : response.getSearchFailures()) { + logger.warn(searchFailure); + } + } + } + + /** + * Load json checkpoint into models + * + * @param checkpoint json checkpoint contents + * @param modelId Model Id + * @return a pair of entity model and its last checkpoint time; or empty if + * the raw checkpoint is too large + */ + public Optional> fromEntityModelCheckpoint(Map checkpoint, String modelId) { + try { + return AccessController.doPrivileged((PrivilegedAction>>) () -> { + Object modelObj = checkpoint.get(FIELD_MODELV2); + if (modelObj == null) { + // in case there is old -format checkpoint + modelObj = checkpoint.get(CommonName.FIELD_MODEL); + } + if (modelObj == null) { + logger.warn(new ParameterizedMessage("Empty model for [{}]", modelId)); + return Optional.empty(); + } + String model = (String) modelObj; + if (model.length() > maxCheckpointBytes) { + logger.warn(new ParameterizedMessage("[{}]'s model too large: [{}] bytes", modelId, model.length())); + return Optional.empty(); + } + JsonObject json = parser.parse(model).getAsJsonObject(); + ArrayDeque samples = null; + if (json.has(CommonName.ENTITY_SAMPLE)) { + // verified, don't need privileged call to get permission + samples = new ArrayDeque<>( + Arrays.asList(this.gson.fromJson(json.getAsJsonArray(CommonName.ENTITY_SAMPLE), new double[0][0].getClass())) + ); + } else { + // avoid possible null pointer exception + samples = new ArrayDeque<>(); + } + ThresholdedRandomCutForest trcf = null; + + if (json.has(ENTITY_TRCF)) { + trcf = toTrcf(json.getAsJsonPrimitive(ENTITY_TRCF).getAsString()); + } else { + Optional rcf = Optional.empty(); + Optional threshold = Optional.empty(); + if (json.has(ENTITY_RCF)) { + String serializedRCF = json.getAsJsonPrimitive(ENTITY_RCF).getAsString(); + rcf = deserializeRCFModel(serializedRCF, modelId); + } + if (json.has(ENTITY_THRESHOLD)) { + // verified, don't need privileged call to get permission + threshold = Optional + .ofNullable( + this.gson.fromJson(json.getAsJsonPrimitive(ENTITY_THRESHOLD).getAsString(), thresholdingModelClass) + ); + } + + Optional convertedTRCF = convertToTRCF(rcf, threshold); + // if checkpoint is corrupted (e.g., some unexpected checkpoint when we missed + // the mark in backward compatibility), we are not gonna load the model part + // the model will have to use live data to initialize + if (convertedTRCF.isPresent()) { + trcf = convertedTRCF.get(); + } + } + + String lastCheckpointTimeString = (String) (checkpoint.get(CommonName.TIMESTAMP)); + Instant timestamp = Instant.parse(lastCheckpointTimeString); + Entity entity = null; + Object serializedEntity = checkpoint.get(CommonName.ENTITY_KEY); + if (serializedEntity != null) { + try { + entity = Entity.fromJsonArray(serializedEntity); + } catch (Exception e) { + logger.error(new ParameterizedMessage("fail to parse entity", serializedEntity), e); + } + } + EntityModel entityModel = new EntityModel(entity, samples, trcf); + return Optional.of(new 
SimpleImmutableEntry<>(entityModel, timestamp)); + }); + } catch (Exception e) { + logger.warn("Exception while deserializing checkpoint " + modelId, e); + // checkpoint corrupted (e.g., a checkpoint not recognized by current code + // due to bugs). Better redo training. + return Optional.empty(); + } + } + + ThresholdedRandomCutForest toTrcf(String checkpoint) { + ThresholdedRandomCutForest trcf = null; + if (checkpoint != null && !checkpoint.isEmpty()) { + try { + byte[] bytes = Base64.getDecoder().decode(checkpoint); + ThresholdedRandomCutForestState state = trcfSchema.newMessage(); + AccessController.doPrivileged((PrivilegedAction) () -> { + ProtostuffIOUtil.mergeFrom(bytes, state, trcfSchema); + return null; + }); + trcf = trcfMapper.toModel(state); + } catch (RuntimeException e) { + logger.error("Failed to deserialize TRCF model", e); + } + } + return trcf; + } + + private Optional deserializeRCFModel(String checkpoint, String modelId) { + if (checkpoint == null || checkpoint.isEmpty()) { + return Optional.empty(); + } + return Optional.ofNullable(AccessController.doPrivileged((PrivilegedAction) () -> { + try { + RandomCutForestState state = converter.convert(checkpoint, Precision.FLOAT_32); + return mapper.toModel(state); + } catch (Exception e) { + logger.error("Unexpected error when deserializing " + modelId, e); + return null; + } + })); + } + + private void deserializeTRCFModel( + GetResponse response, + String rcfModelId, + ActionListener> listener + ) { + Object model = null; + if (response.isExists()) { + try { + model = response.getSource().get(FIELD_MODELV2); + if (model != null) { + listener.onResponse(Optional.ofNullable(toTrcf((String) model))); + } else { + Object modelV1 = response.getSource().get(CommonName.FIELD_MODEL); + Optional forest = deserializeRCFModel((String) modelV1, rcfModelId); + if (!forest.isPresent()) { + logger.error("Unexpected error when deserializing [{}]", rcfModelId); + listener.onResponse(Optional.empty()); + return; + } + String thresholdingModelId = SingleStreamModelIdMapper.getThresholdModelIdFromRCFModelId(rcfModelId); + // query for threshold model and combinne rcf and threshold model into a ThresholdedRandomCutForest + getThresholdModel( + thresholdingModelId, + ActionListener + .wrap( + thresholdingModel -> { listener.onResponse(convertToTRCF(forest, thresholdingModel)); }, + listener::onFailure + ) + ); + } + } catch (Exception e) { + logger.error(new ParameterizedMessage("Unexpected error when deserializing [{}]", rcfModelId), e); + listener.onResponse(Optional.empty()); + } + } else { + listener.onResponse(Optional.empty()); + } + } + + /** + * Read a checkpoint from the index and return the EntityModel object + * @param modelId Model Id + * @param listener Listener to return a pair of entity model and its last checkpoint time + */ + public void deserializeModelCheckpoint(String modelId, ActionListener>> listener) { + clientUtil + .asyncRequest( + new GetRequest(indexName, modelId), + client::get, + ActionListener.wrap(response -> { listener.onResponse(processGetResponse(response, modelId)); }, listener::onFailure) + ); + } + + /** + * Process a checkpoint GetResponse and return the EntityModel object + * @param response Checkpoint Index GetResponse + * @param modelId Model Id + * @return a pair of entity model and its last checkpoint time + */ + public Optional> processGetResponse(GetResponse response, String modelId) { + Optional> checkpointString = processRawCheckpoint(response); + if (checkpointString.isPresent()) { + return 
fromEntityModelCheckpoint(checkpointString.get(), modelId); + } else { + return Optional.empty(); + } + } + + /** + * Returns to listener the checkpoint for the rcf model. + * + * @param modelId id of the model + * @param listener onResponse is called with the model checkpoint, or empty for no such model + */ + public void getTRCFModel(String modelId, ActionListener> listener) { + clientUtil + .asyncRequest( + new GetRequest(indexName, modelId), + client::get, + ActionListener + .wrap( + response -> deserializeTRCFModel(response, modelId, listener), + exception -> { + // expected exception, don't print stack trace + if (exception instanceof IndexNotFoundException) { + listener.onResponse(Optional.empty()); + } else { + listener.onFailure(exception); + } + } + ) + ); + } + + /** + * Returns to listener the checkpoint for the threshold model. + * + * @param modelId id of the model + * @param listener onResponse is called with the model checkpoint, or empty for no such model + */ + public void getThresholdModel(String modelId, ActionListener> listener) { + clientUtil.asyncRequest(new GetRequest(indexName, modelId), client::get, ActionListener.wrap(response -> { + Optional thresholdCheckpoint = processThresholdModelCheckpoint(response); + if (!thresholdCheckpoint.isPresent()) { + listener.onFailure(new ResourceNotFoundException("", "Fail to find model " + modelId)); + return; + } + Optional model = thresholdCheckpoint + .map( + checkpoint -> AccessController + .doPrivileged( + (PrivilegedAction) () -> gson.fromJson((String) checkpoint, thresholdingModelClass) + ) + ); + listener.onResponse(model); + }, + exception -> { + // expected exception, don't print stack trace + if (exception instanceof IndexNotFoundException) { + listener.onResponse(Optional.empty()); + } else { + listener.onFailure(exception); + } + } + )); + } + + private Optional processThresholdModelCheckpoint(GetResponse response) { + return Optional + .ofNullable(response) + .filter(GetResponse::isExists) + .map(GetResponse::getSource) + .map(source -> source.get(CommonName.FIELD_MODEL)); + } + + private Optional> processRawCheckpoint(GetResponse response) { + return Optional.ofNullable(response).filter(GetResponse::isExists).map(GetResponse::getSource); + } + + public void batchRead(MultiGetRequest request, ActionListener listener) { + clientUtil.execute(MultiGetAction.INSTANCE, request, listener); + } + + public void batchWrite(BulkRequest request, ActionListener listener) { + if (indexUtil.doesCheckpointIndexExist()) { + clientUtil.execute(BulkAction.INSTANCE, request, listener); + } else { + indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { + if (initResponse.isAcknowledged()) { + clientUtil.execute(BulkAction.INSTANCE, request, listener); + } else { + // create index failure. Notify callers using listener. 
+ listener.onFailure(new TimeSeriesException("Creating checkpoint with mappings call not acknowledged.")); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + clientUtil.execute(BulkAction.INSTANCE, request, listener); + } else { + logger.error(String.format(Locale.ROOT, "Unexpected error creating checkpoint index"), exception); + listener.onFailure(exception); + } + })); + } + } + + private Optional convertToTRCF(Optional rcf, Optional kllThreshold) { + if (!rcf.isPresent()) { + return Optional.empty(); + } + // if there is no threshold model (e.g., threshold model is deleted by HourlyCron), we are gonna + // start with empty list of rcf scores + List scores = new ArrayList<>(); + if (kllThreshold.isPresent()) { + scores = kllThreshold.get().extractScores(); + } + return Optional.of(new ThresholdedRandomCutForest(rcf.get(), anomalyRate, scores)); + } + + /** + * Should we save the checkpoint or not + * @param lastCheckpointTIme Last checkpoint time + * @param forceWrite Save no matter what + * @param checkpointInterval Checkpoint interval + * @param clock UTC clock + * + * @return true when forceWrite is true or we haven't saved checkpoint in the + * last checkpoint interval; false otherwise + */ + public boolean shouldSave(Instant lastCheckpointTIme, boolean forceWrite, Duration checkpointInterval, Clock clock) { + return (lastCheckpointTIme != Instant.MIN && lastCheckpointTIme.plus(checkpointInterval).isBefore(clock.instant())) || forceWrite; + } +} diff --git a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java b/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java index 242e95622..3f198285f 100644 --- a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java +++ b/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java @@ -37,7 +37,6 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.action.ActionListener; import org.opensearch.action.support.ThreadedActionListener; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.CleanState; import org.opensearch.ad.MaintenanceState; import org.opensearch.ad.NodeStateManager; @@ -51,6 +50,7 @@ import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.dataprocessor.Imputer; @@ -310,7 +310,7 @@ private void coldStart( }); threadPool - .executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME) + .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) .execute( () -> getEntityColdStartData( detectorId, @@ -318,7 +318,7 @@ private void coldStart( new ThreadedActionListener<>( logger, threadPool, - AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, coldStartCallBack, false ) @@ -444,7 +444,7 @@ private void getEntityColdStartData(String detectorId, Entity entity, ActionList .getEntityMinDataTime( detector, entity, - new ThreadedActionListener<>(logger, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, minTimeListener, false) + new ThreadedActionListener<>(logger, threadPool, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, minTimeListener, false) ); }, listener::onFailure); @@ -452,7 +452,7 
@@ private void getEntityColdStartData(String detectorId, Entity entity, ActionList nodeStateManager .getAnomalyDetector( detectorId, - new ThreadedActionListener<>(logger, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, getDetectorListener, false) + new ThreadedActionListener<>(logger, threadPool, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, getDetectorListener, false) ); } @@ -559,7 +559,13 @@ private void getFeatures( // metric is ill-formed, but that cannot be solved by cold-start strategy of the AD plugin — if we attempt to do // that, we will have issues with legitimate interpretations of 0. true, - new ThreadedActionListener<>(logger, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, getFeaturelistener, false) + new ThreadedActionListener<>( + logger, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + getFeaturelistener, + false + ) ); } catch (Exception e) { listener.onFailure(e); diff --git a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java-e b/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java-e new file mode 100644 index 000000000..3f198285f --- /dev/null +++ b/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java-e @@ -0,0 +1,755 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.COOLDOWN_MINUTES; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Queue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.core.util.Throwables; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ThreadedActionListener; +import org.opensearch.ad.CleanState; +import org.opensearch.ad.MaintenanceState; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.caching.DoorKeeper; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.ratelimit.RequestPriority; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.dataprocessor.Imputer; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + 
+import com.amazon.randomcutforest.config.Precision; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * Training models for HCAD detectors + * + */ +public class EntityColdStarter implements MaintenanceState, CleanState { + private static final Logger logger = LogManager.getLogger(EntityColdStarter.class); + private final Clock clock; + private final ThreadPool threadPool; + private final NodeStateManager nodeStateManager; + private final int rcfSampleSize; + private final int numberOfTrees; + private final double rcfTimeDecay; + private final int numMinSamples; + private final double thresholdMinPvalue; + private final int defaulStrideLength; + private final int defaultNumberOfSamples; + private final Imputer imputer; + private final SearchFeatureDao searchFeatureDao; + private Instant lastThrottledColdStartTime; + private final FeatureManager featureManager; + private int coolDownMinutes; + // A bloom filter checked before cold start to ensure we don't repeatedly + // retry cold start of the same model. + // keys are detector ids. + private Map doorKeepers; + private final Duration modelTtl; + private final CheckpointWriteWorker checkpointWriteQueue; + // make sure rcf use a specific random seed. Otherwise, we will use a random random (not a typo) seed. + // this is mainly used for testing to make sure the model we trained and the reference rcf produce + // the same results + private final long rcfSeed; + private final int maxRoundofColdStart; + private final double initialAcceptFraction; + + /** + * Constructor + * + * @param clock UTC clock + * @param threadPool Accessor to different threadpools + * @param nodeStateManager Storing node state + * @param rcfSampleSize The sample size used by stream samplers in this forest + * @param numberOfTrees The number of trees in this forest. + * @param rcfTimeDecay rcf samples time decay constant + * @param numMinSamples The number of points required by stream samplers before + * results are returned. + * @param defaultSampleStride default sample distances measured in detector intervals. + * @param defaultTrainSamples Default train samples to collect. + * @param imputer Used to generate data points between samples. + * @param searchFeatureDao Used to issue ES queries. + * @param thresholdMinPvalue min P-value for thresholding + * @param featureManager Used to create features for models. + * @param settings ES settings accessor + * @param modelTtl time-to-live before last access time of the cold start cache. + * We have a cache to record entities that have run cold starts to avoid + * repeated unsuccessful cold start. 
+ * @param checkpointWriteQueue queue to insert model checkpoints + * @param rcfSeed rcf random seed + * @param maxRoundofColdStart max number of rounds of cold start + */ + public EntityColdStarter( + Clock clock, + ThreadPool threadPool, + NodeStateManager nodeStateManager, + int rcfSampleSize, + int numberOfTrees, + double rcfTimeDecay, + int numMinSamples, + int defaultSampleStride, + int defaultTrainSamples, + Imputer imputer, + SearchFeatureDao searchFeatureDao, + double thresholdMinPvalue, + FeatureManager featureManager, + Settings settings, + Duration modelTtl, + CheckpointWriteWorker checkpointWriteQueue, + long rcfSeed, + int maxRoundofColdStart + ) { + this.clock = clock; + this.lastThrottledColdStartTime = Instant.MIN; + this.threadPool = threadPool; + this.nodeStateManager = nodeStateManager; + this.rcfSampleSize = rcfSampleSize; + this.numberOfTrees = numberOfTrees; + this.rcfTimeDecay = rcfTimeDecay; + this.numMinSamples = numMinSamples; + this.defaulStrideLength = defaultSampleStride; + this.defaultNumberOfSamples = defaultTrainSamples; + this.imputer = imputer; + this.searchFeatureDao = searchFeatureDao; + this.thresholdMinPvalue = thresholdMinPvalue; + this.featureManager = featureManager; + this.coolDownMinutes = (int) (COOLDOWN_MINUTES.get(settings).getMinutes()); + this.doorKeepers = new ConcurrentHashMap<>(); + this.modelTtl = modelTtl; + this.checkpointWriteQueue = checkpointWriteQueue; + this.rcfSeed = rcfSeed; + this.maxRoundofColdStart = maxRoundofColdStart; + this.initialAcceptFraction = numMinSamples * 1.0d / rcfSampleSize; + } + + public EntityColdStarter( + Clock clock, + ThreadPool threadPool, + NodeStateManager nodeStateManager, + int rcfSampleSize, + int numberOfTrees, + double rcfTimeDecay, + int numMinSamples, + int maxSampleStride, + int maxTrainSamples, + Imputer imputer, + SearchFeatureDao searchFeatureDao, + double thresholdMinPvalue, + FeatureManager featureManager, + Settings settings, + Duration modelTtl, + CheckpointWriteWorker checkpointWriteQueue, + int maxRoundofColdStart + ) { + this( + clock, + threadPool, + nodeStateManager, + rcfSampleSize, + numberOfTrees, + rcfTimeDecay, + numMinSamples, + maxSampleStride, + maxTrainSamples, + imputer, + searchFeatureDao, + thresholdMinPvalue, + featureManager, + settings, + modelTtl, + checkpointWriteQueue, + -1, + maxRoundofColdStart + ); + } + + /** + * Training model for an entity + * @param modelId model Id corresponding to the entity + * @param entity the entity's information + * @param detectorId the detector Id corresponding to the entity + * @param modelState model state associated with the entity + * @param listener call back to call after cold start + */ + private void coldStart( + String modelId, + Entity entity, + String detectorId, + ModelState modelState, + AnomalyDetector detector, + ActionListener listener + ) { + logger.debug("Trigger cold start for {}", modelId); + + if (modelState == null || entity == null) { + listener + .onFailure( + new IllegalArgumentException( + String + .format( + Locale.ROOT, + "Cannot have empty model state or entity: model state [%b], entity [%b]", + modelState == null, + entity == null + ) + ) + ); + return; + } + + if (lastThrottledColdStartTime.plus(Duration.ofMinutes(coolDownMinutes)).isAfter(clock.instant())) { + listener.onResponse(null); + return; + } + + boolean earlyExit = true; + try { + DoorKeeper doorKeeper = doorKeepers + .computeIfAbsent( + detectorId, + id -> { + // reset every 60 intervals + return new DoorKeeper( + 
TimeSeriesSettings.DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION, + TimeSeriesSettings.DOOR_KEEPER_FALSE_POSITIVE_RATE, + detector.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ), + clock + ); + } + ); + + // Won't retry cold start within 60 intervals for an entity + if (doorKeeper.mightContain(modelId)) { + return; + } + + doorKeeper.put(modelId); + + ActionListener>> coldStartCallBack = ActionListener.wrap(trainingData -> { + try { + if (trainingData.isPresent()) { + List dataPoints = trainingData.get(); + extractTrainSamples(dataPoints, modelId, modelState); + Queue samples = modelState.getModel().getSamples(); + // only train models if we have enough samples + if (samples.size() >= numMinSamples) { + // The function trainModelFromDataSegments will save a trained a model. trainModelFromDataSegments is called by + // multiple places so I want to make the saving model implicit just in case I forgot. + trainModelFromDataSegments(samples, entity, modelState, detector.getShingleSize()); + logger.info("Succeeded in training entity: {}", modelId); + } else { + // save to checkpoint + checkpointWriteQueue.write(modelState, true, RequestPriority.MEDIUM); + logger.info("Not enough data to train entity: {}, currently we have {}", modelId, samples.size()); + } + } else { + logger.info("Cannot get training data for {}", modelId); + } + listener.onResponse(null); + } catch (Exception e) { + listener.onFailure(e); + } + }, exception -> { + try { + logger.error(new ParameterizedMessage("Error while cold start {}", modelId), exception); + Throwable cause = Throwables.getRootCause(exception); + if (ExceptionUtil.isOverloaded(cause)) { + logger.error("too many requests"); + lastThrottledColdStartTime = Instant.now(); + } else if (cause instanceof TimeSeriesException || exception instanceof TimeSeriesException) { + // e.g., cannot find anomaly detector + nodeStateManager.setException(detectorId, exception); + } else { + nodeStateManager.setException(detectorId, new TimeSeriesException(detectorId, cause)); + } + listener.onFailure(exception); + } catch (Exception e) { + listener.onFailure(e); + } + }); + + threadPool + .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) + .execute( + () -> getEntityColdStartData( + detectorId, + entity, + new ThreadedActionListener<>( + logger, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + coldStartCallBack, + false + ) + ) + ); + earlyExit = false; + } finally { + if (earlyExit) { + listener.onResponse(null); + } + } + } + + /** + * Train model using given data points and save the trained model. 
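+     *
+     * For example (illustrative numbers only, not taken from this change): with two features per
+     * data point and a shingle size of 8, the forest is built with dimensions = 2 * 8 = 16 and
+     * internal shingling enabled.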
+ * + * @param dataPoints Queue of continuous data points, in ascending order of timestamps + * @param entity Entity instance + * @param entityState Entity state associated with the model Id + */ + private void trainModelFromDataSegments( + Queue dataPoints, + Entity entity, + ModelState entityState, + int shingleSize + ) { + if (dataPoints == null || dataPoints.size() == 0) { + throw new IllegalArgumentException("Data points must not be empty."); + } + + double[] firstPoint = dataPoints.peek(); + if (firstPoint == null || firstPoint.length == 0) { + throw new IllegalArgumentException("Data points must not be empty."); + } + int dimensions = firstPoint.length * shingleSize; + ThresholdedRandomCutForest.Builder rcfBuilder = ThresholdedRandomCutForest + .builder() + .dimensions(dimensions) + .sampleSize(rcfSampleSize) + .numberOfTrees(numberOfTrees) + .timeDecay(rcfTimeDecay) + .outputAfter(numMinSamples) + .initialAcceptFraction(initialAcceptFraction) + .parallelExecutionEnabled(false) + .compact(true) + .precision(Precision.FLOAT_32) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + // same with dimension for opportunistic memory saving + // Usually, we use it as shingleSize(dimension). When a new point comes in, we will + // look at the point store if there is any overlapping. Say the previously-stored + // vector is x1, x2, x3, x4, now we add x3, x4, x5, x6. RCF will recognize + // overlapping x3, x4, and only store x5, x6. + .shingleSize(shingleSize) + .internalShinglingEnabled(true) + .anomalyRate(1 - this.thresholdMinPvalue); + + if (rcfSeed > 0) { + rcfBuilder.randomSeed(rcfSeed); + } + ThresholdedRandomCutForest trcf = new ThresholdedRandomCutForest(rcfBuilder); + + while (!dataPoints.isEmpty()) { + trcf.process(dataPoints.poll(), 0); + } + + EntityModel model = entityState.getModel(); + if (model == null) { + model = new EntityModel(entity, new ArrayDeque<>(), null); + } + model.setTrcf(trcf); + + entityState.setLastUsedTime(clock.instant()); + + // save to checkpoint + checkpointWriteQueue.write(entityState, true, RequestPriority.MEDIUM); + } + + /** + * Get training data for an entity. + * + * We first note the maximum and minimum timestamp, and sample at most 24 points + * (with 60 points apart between two neighboring samples) between those minimum + * and maximum timestamps. Samples can be missing. We only interpolate points + * between present neighboring samples. We then transform samples and interpolate + * points to shingles. Finally, full shingles will be used for cold start. + * + * @param detectorId detector Id + * @param entity the entity's information + * @param listener listener to return training data + */ + private void getEntityColdStartData(String detectorId, Entity entity, ActionListener>> listener) { + ActionListener> getDetectorListener = ActionListener.wrap(detectorOp -> { + if (!detectorOp.isPresent()) { + listener.onFailure(new EndRunException(detectorId, "AnomalyDetector is not available.", false)); + return; + } + List coldStartData = new ArrayList<>(); + AnomalyDetector detector = detectorOp.get(); + + ActionListener> minTimeListener = ActionListener.wrap(earliest -> { + if (earliest.isPresent()) { + long startTimeMs = earliest.get().longValue(); + + // End time uses milliseconds as start time is assumed to be in milliseconds. + // Opensearch uses a set of preconfigured formats to recognize and parse these + // strings into a long value + // representing milliseconds-since-the-epoch in UTC. 
+ // More on https://tinyurl.com/wub4fk92 + + long endTimeMs = clock.millis(); + Pair params = selectRangeParam(detector); + int stride = params.getLeft(); + int numberOfSamples = params.getRight(); + + // we start with round 0 + getFeatures(listener, 0, coldStartData, detector, entity, stride, numberOfSamples, startTimeMs, endTimeMs); + } else { + listener.onResponse(Optional.empty()); + } + }, listener::onFailure); + + searchFeatureDao + .getEntityMinDataTime( + detector, + entity, + new ThreadedActionListener<>(logger, threadPool, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, minTimeListener, false) + ); + + }, listener::onFailure); + + nodeStateManager + .getAnomalyDetector( + detectorId, + new ThreadedActionListener<>(logger, threadPool, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, getDetectorListener, false) + ); + } + + private void getFeatures( + ActionListener>> listener, + int round, + List lastRoundColdStartData, + AnomalyDetector detector, + Entity entity, + int stride, + int numberOfSamples, + long startTimeMs, + long endTimeMs + ) { + if (startTimeMs >= endTimeMs || endTimeMs - startTimeMs < detector.getIntervalInMilliseconds()) { + listener.onResponse(Optional.of(lastRoundColdStartData)); + return; + } + + // create ranges in desending order, we will reorder it in ascending order + // in Opensearch's response + List> sampleRanges = getTrainSampleRanges(detector, startTimeMs, endTimeMs, stride, numberOfSamples); + + if (sampleRanges.isEmpty()) { + listener.onResponse(Optional.of(lastRoundColdStartData)); + return; + } + + ActionListener>> getFeaturelistener = ActionListener.wrap(featureSamples -> { + // storing lastSample = null; + List currentRoundColdStartData = new ArrayList<>(); + + // featuresSamples are in ascending order of time. + for (int i = 0; i < featureSamples.size(); i++) { + Optional featuresOptional = featureSamples.get(i); + if (featuresOptional.isPresent()) { + // we only need the most recent two samples + // For the missing samples we use linear interpolation as well. + // Denote the Samples S0, S1, ... as samples in reverse order of time. + // Each [Si​,Si−1​]corresponds to strideLength * detector interval. + // If we got samples for S0, S1, S4 (both S2 and S3 are missing), then + // we interpolate the [S4,S1] into 3*strideLength pieces. + if (lastSample != null) { + // right sample has index i and feature featuresOptional.get() + int numInterpolants = (i - lastSample.getLeft()) * stride + 1; + double[][] points = featureManager + .transpose( + imputer + .impute( + featureManager.transpose(new double[][] { lastSample.getRight(), featuresOptional.get() }), + numInterpolants + ) + ); + // the last point will be included in the next iteration or we process + // it in the end. We don't want to repeatedly include the samples twice. + currentRoundColdStartData.add(Arrays.copyOfRange(points, 0, points.length - 1)); + } + lastSample = Pair.of(i, featuresOptional.get()); + } + } + + if (lastSample != null) { + currentRoundColdStartData.add(new double[][] { lastSample.getRight() }); + } + if (lastRoundColdStartData.size() > 0) { + currentRoundColdStartData.addAll(lastRoundColdStartData); + } + + // If the first round of probe provides (32+shingleSize) points (note that if S0 is + // missing or all Si​ for some i > N is missing then we would miss a lot of points. + // Otherwise we can issue another round of query — if there is any sample in the + // second round then we would have 32 + shingleSize points. 
If there is no sample + // in the second round then we should wait for real data. + if (calculateColdStartDataSize(currentRoundColdStartData) >= detector.getShingleSize() + numMinSamples + || round + 1 >= maxRoundofColdStart) { + listener.onResponse(Optional.of(currentRoundColdStartData)); + } else { + // the last sample's start time is the endTimeMs of next round of probe. + long lastSampleStartTime = sampleRanges.get(sampleRanges.size() - 1).getKey(); + getFeatures( + listener, + round + 1, + currentRoundColdStartData, + detector, + entity, + stride, + numberOfSamples, + startTimeMs, + lastSampleStartTime + ); + } + }, listener::onFailure); + + try { + searchFeatureDao + .getColdStartSamplesForPeriods( + detector, + sampleRanges, + entity, + // Accept empty bucket. + // 0, as returned by the engine should constitute a valid answer, “null” is a missing answer — it may be that 0 + // is meaningless in some case, but 0 is also meaningful in some cases. It may be that the query defining the + // metric is ill-formed, but that cannot be solved by cold-start strategy of the AD plugin — if we attempt to do + // that, we will have issues with legitimate interpretations of 0. + true, + new ThreadedActionListener<>( + logger, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + getFeaturelistener, + false + ) + ); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private int calculateColdStartDataSize(List coldStartData) { + int size = 0; + for (int i = 0; i < coldStartData.size(); i++) { + size += coldStartData.get(i).length; + } + return size; + } + + /** + * Select strideLength and numberOfSamples, where stride is the number of intervals + * between two samples and trainSamples is training samples to fetch. If we disable + * interpolation, strideLength is 1 and numberOfSamples is shingleSize + numMinSamples; + * + * Algorithm: + * + * delta is the length of the detector interval in minutes. + * + * 1. Suppose delta ≤ 30 and divides 60. Then set numberOfSamples = ceil ( (shingleSize + 32)/ 24 )*24 + * and strideLength = 60/delta. Note that if there is enough data — we may have lot more than shingleSize+32 + * points — which is only good. This step tries to match data with hourly pattern. + * 2. otherwise, set numberOfSamples = (shingleSize + 32) and strideLength = 1. + * This should be an uncommon case as we are assuming most users think in terms of multiple of 5 minutes + *(say 10 or 30 minutes). But if someone wants a 23 minutes interval —- and the system permits -- + * we give it to them. In this case, we disable interpolation as we want to interpolate based on the hourly pattern. + * That's why we use 60 as a dividend in case 1. The 23 minute case does not fit that pattern. + * Note the smallest delta that does not divide 60 is 7 which is quite large to wait for one data point. 
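+     *
+     * Worked example (illustrative numbers only): for a 10-minute detector interval with
+     * shingleSize = 8 and 32 minimum samples, case 1 applies since 10 divides 60, so
+     * strideLength = 60 / 10 = 6 and numberOfSamples = ceil((8 + 32) / 24.0) * 24 = 48.
+     * For a 23-minute interval, case 2 applies: strideLength = 1 and numberOfSamples = 8 + 32 = 40.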
+ * @return the chosen strideLength and numberOfSamples + */ + private Pair selectRangeParam(AnomalyDetector detector) { + int shingleSize = detector.getShingleSize(); + if (ADEnabledSetting.isInterpolationInColdStartEnabled()) { + long delta = detector.getIntervalInMinutes(); + + int strideLength = defaulStrideLength; + int numberOfSamples = defaultNumberOfSamples; + if (delta <= 30 && 60 % delta == 0) { + strideLength = (int) (60 / delta); + numberOfSamples = (int) Math.ceil((shingleSize + numMinSamples) / 24.0d) * 24; + } else { + strideLength = 1; + numberOfSamples = shingleSize + numMinSamples; + } + return Pair.of(strideLength, numberOfSamples); + } else { + return Pair.of(1, shingleSize + numMinSamples); + } + + } + + /** + * Get train samples within a time range. + * + * @param detector accessor to detector config + * @param startMilli range start + * @param endMilli range end + * @param stride the number of intervals between two samples + * @param numberOfSamples maximum training samples to fetch + * @return list of sample time ranges + */ + private List> getTrainSampleRanges( + AnomalyDetector detector, + long startMilli, + long endMilli, + int stride, + int numberOfSamples + ) { + long bucketSize = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().toMillis(); + int numBuckets = (int) Math.floor((endMilli - startMilli) / (double) bucketSize); + // adjust if numStrides is more than the max samples + int numStrides = Math.min((int) Math.floor(numBuckets / (double) stride), numberOfSamples); + List> sampleRanges = Stream + .iterate(endMilli, i -> i - stride * bucketSize) + .limit(numStrides) + .map(time -> new SimpleImmutableEntry<>(time - bucketSize, time)) + .collect(Collectors.toList()); + return sampleRanges; + } + + /** + * Train models for the given entity + * @param entity The entity info + * @param detectorId Detector Id + * @param modelState Model state associated with the entity + * @param listener callback before the method returns whenever EntityColdStarter + * finishes training or encounters exceptions. The listener helps notify the + * cold start queue to pull another request (if any) to execute. 
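+     *
+     * <p>A minimal caller sketch (hypothetical wiring; in practice the cold start queue drives
+     * this call):
+     * <pre>{@code
+     * coldStarter.trainModel(entity, detectorId, modelState, ActionListener.wrap(
+     *     unused -> logger.debug("cold start finished for detector {}", detectorId), // illustrative
+     *     e -> logger.warn("cold start failed for detector " + detectorId, e)));
+     * }</pre>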
+ */ + public void trainModel(Entity entity, String detectorId, ModelState modelState, ActionListener listener) { + nodeStateManager.getAnomalyDetector(detectorId, ActionListener.wrap(detectorOptional -> { + if (false == detectorOptional.isPresent()) { + logger.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + listener.onFailure(new TimeSeriesException(detectorId, "fail to find detector")); + return; + } + + AnomalyDetector detector = detectorOptional.get(); + + Queue samples = modelState.getModel().getSamples(); + String modelId = modelState.getModelId(); + + if (samples.size() < this.numMinSamples) { + // we cannot get last RCF score since cold start happens asynchronously + coldStart(modelId, entity, detectorId, modelState, detector, listener); + } else { + try { + trainModelFromDataSegments(samples, entity, modelState, detector.getShingleSize()); + listener.onResponse(null); + } catch (Exception e) { + listener.onFailure(e); + } + } + + }, listener::onFailure)); + } + + public void trainModelFromExistingSamples(ModelState modelState, int shingleSize) { + if (modelState == null || modelState.getModel() == null || modelState.getModel().getSamples() == null) { + return; + } + + EntityModel model = modelState.getModel(); + Queue samples = model.getSamples(); + if (samples.size() >= this.numMinSamples) { + try { + trainModelFromDataSegments(samples, model.getEntity().orElse(null), modelState, shingleSize); + } catch (Exception e) { + // e.g., exception from rcf. We can do nothing except logging the error + // We won't retry training for the same entity in the cooldown period + // (60 detector intervals). + logger.error("Unexpected training error", e); + } + + } + } + + /** + * Extract training data and put them into ModelState + * + * @param coldstartDatapoints training data generated from cold start + * @param modelId model Id + * @param modelState entity State + */ + private void extractTrainSamples(List coldstartDatapoints, String modelId, ModelState modelState) { + if (coldstartDatapoints == null || coldstartDatapoints.size() == 0 || modelState == null) { + return; + } + + EntityModel model = modelState.getModel(); + if (model == null) { + model = new EntityModel(null, new ArrayDeque<>(), null); + modelState.setModel(model); + } + + Queue newSamples = new ArrayDeque<>(); + for (double[][] consecutivePoints : coldstartDatapoints) { + for (int i = 0; i < consecutivePoints.length; i++) { + newSamples.add(consecutivePoints[i]); + } + } + + model.setSamples(newSamples); + } + + @Override + public void maintenance() { + doorKeepers.entrySet().stream().forEach(doorKeeperEntry -> { + String detectorId = doorKeeperEntry.getKey(); + DoorKeeper doorKeeper = doorKeeperEntry.getValue(); + if (doorKeeper.expired(modelTtl)) { + doorKeepers.remove(detectorId); + } else { + doorKeeper.maintenance(); + } + }); + } + + @Override + public void clear(String detectorId) { + doorKeepers.remove(detectorId); + } +} diff --git a/src/main/java/org/opensearch/ad/ml/EntityModel.java-e b/src/main/java/org/opensearch/ad/ml/EntityModel.java-e new file mode 100644 index 000000000..348ad8c6e --- /dev/null +++ b/src/main/java/org/opensearch/ad/ml/EntityModel.java-e @@ -0,0 +1,92 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ml; + +import java.util.ArrayDeque; +import java.util.Optional; +import java.util.Queue; + +import org.opensearch.timeseries.model.Entity; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class EntityModel { + private Entity entity; + // TODO: sample should record timestamp + private Queue samples; + + private ThresholdedRandomCutForest trcf; + + /** + * Constructor with TRCF. + * + * @param entity entity if any + * @param samples samples with the model + * @param trcf thresholded rcf model + */ + public EntityModel(Entity entity, Queue samples, ThresholdedRandomCutForest trcf) { + this.entity = entity; + this.samples = samples; + this.trcf = trcf; + } + + /** + * In old checkpoint mapping, we don't have entity. It's fine we are missing + * entity as it is mostly used for debugging. + * @return entity + */ + public Optional getEntity() { + return Optional.ofNullable(entity); + } + + public Queue getSamples() { + return this.samples; + } + + public void setSamples(Queue samples) { + this.samples = samples; + } + + public void addSample(double[] sample) { + if (this.samples == null) { + this.samples = new ArrayDeque<>(); + } + if (sample != null && sample.length != 0) { + this.samples.add(sample); + } + } + + /** + * Sets an trcf model. + * + * @param trcf an trcf model + */ + public void setTrcf(ThresholdedRandomCutForest trcf) { + this.trcf = trcf; + } + + /** + * Returns optional trcf model. + * + * @return the trcf model or empty + */ + public Optional getTrcf() { + return Optional.ofNullable(this.trcf); + } + + public void clear() { + if (samples != null) { + samples.clear(); + } + trcf = null; + } +} diff --git a/src/main/java/org/opensearch/ad/ml/HybridThresholdingModel.java-e b/src/main/java/org/opensearch/ad/ml/HybridThresholdingModel.java-e new file mode 100644 index 000000000..f4f4bf4b6 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ml/HybridThresholdingModel.java-e @@ -0,0 +1,373 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.math3.special.Erf; +import org.apache.commons.math3.stat.descriptive.SummaryStatistics; + +import com.google.gson.annotations.Expose; +import com.google.gson.annotations.JsonAdapter; +import com.yahoo.sketches.kll.KllFloatsSketch; +import com.yahoo.sketches.kll.KllFloatsSketchIterator; + +/** + * A model for converting raw anomaly scores into anomaly grades. + * + * The hybrid thresholding model combines a log-normal distribution model/CDF as + * well as an empirical model/CDF for determining anomalous scores. The + * log-normal CDF is used to initialize the empirical CDF. This is done because + * the training set often does not include anomalies and predictions need to be + * made as to how often a large anomaly score will occur given the training + * data. The empirical model uses the technique described in "Optimal Quantile + * Approximation in Streams" by Karnin, Lang, and Liberty. The KLL model is + * implemented in {@code KllFloatSketchSerDe}. 
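+ *
+ * <p>A minimal usage sketch with illustrative parameter values (chosen to satisfy the
+ * constructor's constraints; they are not defaults defined by this class):
+ * <pre>{@code
+ * HybridThresholdingModel threshold =
+ *     new HybridThresholdingModel(0.995, 1e-4, 8.0, 10_000, 10_000, 1_000_000);
+ * threshold.train(trainingScores);          // fit a log-normal CDF and seed the KLL sketch
+ * threshold.update(rcfScore);               // keep the empirical CDF current
+ * double grade = threshold.grade(rcfScore); // 0 means not anomalous, (0, 1] is severity
+ * }</pre>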
+ * + * @see KllFloatsSketchSerDe + */ +public class HybridThresholdingModel implements ThresholdingModel { + + /** + * The minimum anomaly score for labeling anomalies. + * Scores below this value result into anomaly grade 0. + */ + public static final double MIN_SCORE = 0.4; + + private static final boolean USE_DOUBLE_SIDED_ERROR = true; + private static final double CONFIDENCE = 0.99; + + @Expose + @JsonAdapter(KllFloatsSketchSerDe.class) + private KllFloatsSketch quantileSketch; + private double maxScore; + private int numLogNormalQuantiles; + private double minPvalueThreshold; + private int downsampleNumSamples; + private long downsampleMaxNumObservations; + + /** + * Initializes a HybridThresholdingModel. + * + * The primary parameters to a HybridThresholdingModel are {@code + * minPvalueThreshold} and {@code maxRankError}. These two parameters define + * the p-value threshold at which anomalies are reported and how accurately + * this threshold can be measured. Note that in order to accurately measure + * the given minimum p-value threshold the maximum rank error needs to be + * small enough. In particular, {@code maxRankError} should be less than + * {@code minPvalueThreshold}. + * + * The maximum possible anomaly score is provided so that large anomaly + * scores can be estimated during training; even when they are not available + * in the training set. The quantile sketch is initialized using {@code + * numLogNormalQuantiles} quantile results from fitting a log-normal + * distribution to the training data. + * + * The size of the empirical CDF, as modeled by the KLL algorithm, is + * determined by the parameter, {@code maxRankError}. When a certain number + * of updates/observations have been processed, as set by {@code + * downsampleMaxNumObservations}, the ECDF model will be automatically + * downsampled to a number of scores equal to {@code downsampleNumSamples}. + * + * @param minPvalueThreshold the p-value threshold beyond which an + * anomaly score is classified as an + * panomaly. A value of 0.995 is + * recommended. (Between: 0 and 1) + * @param maxRankError desired maximum double-sided normalized + * rank error to use in the quantile + * sketch approximation. A value of 0.0001 + * is recommended. + * @param maxScore the largest observable anomaly score + * @param numLogNormalQuantiles number of quantiles of the log-normal + * distribution to compute training. + * (Min: 0) + * @param downsampleNumSamples the number of scores to keep when + * downsampling the model. A value of + * 10_000 is recommended. + * @param downsampleMaxNumObservations the threshold number of observations / + * updates at which the model will be + * automatically downsampled. A value of + * 1_000_000 is recommended. 
+ * @throws IllegalArgumentException if {@code minPvalueThreshold} is not + * strictly between 0 and 1, if {@code + * maxRankError} is larger than 1 - + * {@code minPvalueThreshold}, or if + * {@code numLogNormalQuantiles} + * negative + * + * @see KllFloatsSketchSerDe + */ + public HybridThresholdingModel( + double minPvalueThreshold, + double maxRankError, + double maxScore, + int numLogNormalQuantiles, + int downsampleNumSamples, + long downsampleMaxNumObservations + ) { + if ((minPvalueThreshold <= 0.0) || (1.0 <= minPvalueThreshold)) { + throw new IllegalArgumentException("minPvalueThreshold must be strictly between 0 and 1."); + } + if (maxRankError > (1.0 - minPvalueThreshold)) { + throw new IllegalArgumentException( + "maxRankError must be smaller than 1 - minPvalueThreshold in order to accurately " + "estimate that threshold." + ); + } + if (maxRankError <= 0.0) { + throw new IllegalArgumentException("maxRankError must be positive."); + } + if (maxScore <= 0.0) { + throw new IllegalArgumentException("maxScore must be positive."); + } + if (numLogNormalQuantiles < 0) { + throw new IllegalArgumentException("The maximum number of log-normal quantiles to compute must be non-negative."); + } + if (downsampleNumSamples <= 1) { + throw new IllegalArgumentException("Number of downsamples must be greater than one."); + } + if (downsampleNumSamples >= downsampleMaxNumObservations) { + throw new IllegalArgumentException( + "The number of samples to downsample to must be less than the number of observations " + "before downsampling is triggered." + ); + } + + this.minPvalueThreshold = minPvalueThreshold; + this.quantileSketch = new KllFloatsSketch(KllFloatsSketch.getKFromEpsilon(maxRankError, USE_DOUBLE_SIDED_ERROR)); + this.maxScore = maxScore; + this.numLogNormalQuantiles = numLogNormalQuantiles; + this.downsampleNumSamples = downsampleNumSamples; + this.downsampleMaxNumObservations = downsampleMaxNumObservations; + } + + /** + * Empty constructor only for serialization purpose - DO NOT USE. + * + * WARNING. All clients should avoid using this constructor + * for the objects from this constructor have undefined behaviors. + * This constructor is exclusively used for serialization. + */ + public HybridThresholdingModel() {} + + /** + * Returns the minimum p-value threshold for anomaly classification. + * + * @return minPvalueThreshold + */ + public double getMinPvalueThreshold() { + return minPvalueThreshold; + } + + /** + * Returns the approximate double-sided normalized rank error of the quantile sketch. + * + * @return MaxRankError + */ + public double getMaxRankError() { + return quantileSketch.getNormalizedRankError(USE_DOUBLE_SIDED_ERROR); + } + + /** + * Returns the maximum possible anomaly score of the thresholding model. + * + * @return maxScore + */ + public double getMaxScore() { + return maxScore; + } + + /** + * Returns the number of log-normal quantiles used to initialize the + * quantile sketch. + * + * @return numLogNormalQuantiles + */ + public int getNumLogNormalQuantiles() { + return numLogNormalQuantiles; + } + + /** + * Returns the number of samples to retain when downsampling the model. + * + * @return downsampleNumSamples + */ + public int getDownsampleNumSamples() { + return downsampleNumSamples; + } + + /** + * Returns the number of observations that triggers a model downsampling. 
+ * + * @return downsampleMaxNumObservations + */ + public long getDownsampleMaxNumObservations() { + return downsampleMaxNumObservations; + } + + /** + * Initializes the model using a training set of anomaly scores. + * + * The hybrid model initialization has several steps. First, a log-normal + * distribution is fit to the training set scores. Next, the quantile sketch + * is initialized with at {@code numLogNormalQuantiles} samples from the + * log-normal model up to {@code maxScore}. + * + * @param anomalyScores an array of anomaly scores with which to train the model. + */ + @Override + public void train(double[] anomalyScores) { + /* + We assume the anomaly scores are fit to a log-normal distribution. + Equivalent to fitting a Gaussian to the logs of the anomaly scores. + */ + SummaryStatistics stats = new SummaryStatistics(); + for (int i = 0; i < anomalyScores.length; i++) { + stats.addValue(Math.log(anomalyScores[i])); + } + final double mu = stats.getMean(); + final double sigma = stats.getStandardDeviation(); + + /* + Compute the 1/R quantiles for R = `numLogNormalQuantiles` of the + corresponding log-normal distribution and use these to initialize the + model. We only compute p-values up to the p-value of the known maximum + possible score. Finally, we do not compute the p=0.0 quantile because + raw anomaly scores are positive and non-zero. + */ + final double maxScorePvalue = computeLogNormalCdf(maxScore, mu, sigma); + final double pvalueStep = maxScorePvalue / (numLogNormalQuantiles + 1.0); + for (double pvalue = pvalueStep; pvalue < maxScorePvalue; pvalue += pvalueStep) { + double currentScore = computeLogNormalQuantile(pvalue, mu, sigma); + update(currentScore); + } + } + + /** + * The log-normal cumulative distribution function. + * + * Given and anomaly score compute the corresponding p-value. + * + * @param anomalyScore an anomaly score + * @param mu mean parameter of the log-normal distribution + * @param sigma standard deviation of the log-normal distribution + * @return the p-value of the input anomaly score + */ + private double computeLogNormalCdf(double anomalyScore, double mu, double sigma) { + return (1.0 + Erf.erf((Math.log(anomalyScore) - mu) / (Math.sqrt(2.0) * sigma))) / 2.0; + } + + /** + * The log-normal quantile function. + * + * Given a p-value and log-normal distribution parameters compute the + * corresponding anomaly score. + * + * @param pvalue a p-value between 0 and 1 + * @param mu mean parameter of the log-normal distribution + * @param sigma standard deviation of the log-normal distribution + * @return anomaly score at the given p-value quantile + */ + private double computeLogNormalQuantile(double pvalue, double mu, double sigma) { + return Math.exp(mu + Math.sqrt(2.0) * sigma * Erf.erfInv(2.0 * pvalue - 1.0)); + } + + /** + * Updates the model with a new anomaly score. + * + * Note that once we initialize the hybrid model we only update the + * empirical CDF. The model is downsampled when the total number of + * observations/updates exceeds {@code downsampleMaxNumObservations}. + * + * @param anomalyScore an anomaly score. + * @see HybridThresholdingModel + */ + @Override + public void update(double anomalyScore) { + quantileSketch.update((float) anomalyScore); + + long totalNumObservations = quantileSketch.getN(); + if (totalNumObservations >= downsampleMaxNumObservations) { + downsample(); + } + } + + /** + * Computes the anomaly grade associated with the given anomaly score. A + * non-zero grade implies that the given score is anomalous. 
The magnitude + * of the grade, a value between 0 and 1, indicates the severity of the + * anomaly. + * + * @param anomalyScore an anomaly score + * @return the associated anomaly grade + */ + @Override + public double grade(double anomalyScore) { + double anomalyGrade = 0.; + if (anomalyScore > MIN_SCORE) { + final double scale = 1.0 / (1.0 - minPvalueThreshold); + final double pvalue = quantileSketch.getRank((float) anomalyScore); + anomalyGrade = scale * (pvalue - minPvalueThreshold); + anomalyGrade = Double.isNaN(anomalyGrade) ? 0. : Math.max(0., anomalyGrade); + } + return anomalyGrade; + } + + /** + * Returns the confidence of the model in predicting anomaly grades; that + * is, the probability that the reported anomaly grade is correct according + * to the underlying model. + * + * For the HybridThresholdingModel the model confidence is from underlying Sketch. + * + * @return the model confidence. + * @see + */ + @Override + public double confidence() { + return CONFIDENCE; + } + + /** + * Replaces the model's ECDF sketch with a downsampled version. + * + * Periodic downsampling of the sketch is primarily useful for allowing the + * model to more easily adapt to changes in the score distribution. This is + * because, with fewer retained points in the sketch, new scores will have a + * larger impact on the distribution. A secondary benefit to the + * downsampling process is to prevent out of memory errors; although the + * memory requirements grows like log(log(N)) there is the chance of an + * allocation bug which we wish to mitigate. + * + * Uses the initialization parameter {@code downsampleNumSamples}. + * + */ + private void downsample() { + KllFloatsSketch downsampledQuantileSketch = new KllFloatsSketch(quantileSketch.getK()); + double pvalueStep = 1.0 / (downsampleNumSamples - 1.0); + for (double pvalue = 0.0; pvalue < 1.0; pvalue += pvalueStep) { + float score = quantileSketch.getQuantile(pvalue); + downsampledQuantileSketch.update(score); + } + downsampledQuantileSketch.update((float) maxScore); + this.quantileSketch = downsampledQuantileSketch; + } + + @Override + public List extractScores() { + KllFloatsSketchIterator iter = quantileSketch.iterator(); + List scores = new ArrayList<>(); + while (iter.next()) { + scores.add((double) iter.getValue()); + } + return scores; + } +} diff --git a/src/main/java/org/opensearch/ad/ml/KllFloatsSketchSerDe.java-e b/src/main/java/org/opensearch/ad/ml/KllFloatsSketchSerDe.java-e new file mode 100644 index 000000000..a6664603e --- /dev/null +++ b/src/main/java/org/opensearch/ad/ml/KllFloatsSketchSerDe.java-e @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import java.lang.reflect.Type; +import java.util.Base64; + +import com.google.gson.JsonDeserializationContext; +import com.google.gson.JsonDeserializer; +import com.google.gson.JsonElement; +import com.google.gson.JsonPrimitive; +import com.google.gson.JsonSerializationContext; +import com.google.gson.JsonSerializer; +import com.yahoo.memory.Memory; +import com.yahoo.sketches.kll.KllFloatsSketch; + +/** + * Serializes/deserailizes KllFloatsSketch. + * + * A sketch is serialized to a byte array and then encoded in Base64. 
+ */ +public class KllFloatsSketchSerDe implements JsonSerializer, JsonDeserializer { + + @Override + public JsonElement serialize(KllFloatsSketch src, Type typeOfSrc, JsonSerializationContext context) { + return new JsonPrimitive(Base64.getEncoder().encodeToString(src.toByteArray())); + } + + @Override + public KllFloatsSketch deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) { + return KllFloatsSketch.heapify(Memory.wrap(Base64.getDecoder().decode(json.getAsString()))); + } +} diff --git a/src/main/java/org/opensearch/ad/ml/ModelManager.java-e b/src/main/java/org/opensearch/ad/ml/ModelManager.java-e new file mode 100644 index 000000000..464297193 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ml/ModelManager.java-e @@ -0,0 +1,827 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.DetectorModelSize; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.util.DateUtils; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.config.Precision; +import com.amazon.randomcutforest.parkservices.AnomalyDescriptor; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * A facade managing ML operations and models. 
+ */ +public class ModelManager implements DetectorModelSize { + protected static final String ENTITY_SAMPLE = "sp"; + protected static final String ENTITY_RCF = "rcf"; + protected static final String ENTITY_THRESHOLD = "th"; + + public enum ModelType { + RCF("rcf"), + THRESHOLD("threshold"), + ENTITY("entity"); + + private String name; + + ModelType(String name) { + this.name = name; + } + + public String getName() { + return name; + } + } + + private static final Logger logger = LogManager.getLogger(ModelManager.class); + + // states + private TRCFMemoryAwareConcurrentHashmap forests; + private Map> thresholds; + + // configuration + private final int rcfNumTrees; + private final int rcfNumSamplesInTree; + private final double rcfTimeDecay; + private final int rcfNumMinSamples; + private final double thresholdMinPvalue; + private final int minPreviewSize; + private final Duration modelTtl; + private Duration checkpointInterval; + + // dependencies + private final CheckpointDao checkpointDao; + private final Clock clock; + public FeatureManager featureManager; + + private EntityColdStarter entityColdStarter; + private MemoryTracker memoryTracker; + + private final double initialAcceptFraction; + + /** + * Constructor. + * + * @param checkpointDao model checkpoint storage + * @param clock clock for system time + * @param rcfNumTrees number of trees used in RCF + * @param rcfNumSamplesInTree number of samples in a RCF tree + * @param rcfTimeDecay time decay for RCF + * @param rcfNumMinSamples minimum samples for RCF to score + * @param thresholdMinPvalue min P-value for thresholding + * @param minPreviewSize minimum number of data points for preview + * @param modelTtl time to live for hosted models + * @param checkpointIntervalSetting setting of interval between checkpoints + * @param entityColdStarter HCAD cold start utility + * @param featureManager Used to create features for models + * @param memoryTracker AD memory usage tracker + * @param settings Node settings + * @param clusterService Cluster service accessor + */ + public ModelManager( + CheckpointDao checkpointDao, + Clock clock, + int rcfNumTrees, + int rcfNumSamplesInTree, + double rcfTimeDecay, + int rcfNumMinSamples, + double thresholdMinPvalue, + int minPreviewSize, + Duration modelTtl, + Setting checkpointIntervalSetting, + EntityColdStarter entityColdStarter, + FeatureManager featureManager, + MemoryTracker memoryTracker, + Settings settings, + ClusterService clusterService + ) { + this.checkpointDao = checkpointDao; + this.clock = clock; + this.rcfNumTrees = rcfNumTrees; + this.rcfNumSamplesInTree = rcfNumSamplesInTree; + this.rcfTimeDecay = rcfTimeDecay; + this.rcfNumMinSamples = rcfNumMinSamples; + this.thresholdMinPvalue = thresholdMinPvalue; + this.minPreviewSize = minPreviewSize; + this.modelTtl = modelTtl; + this.checkpointInterval = DateUtils.toDuration(checkpointIntervalSetting.get(settings)); + if (clusterService != null) { + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(checkpointIntervalSetting, it -> this.checkpointInterval = DateUtils.toDuration(it)); + } + + this.forests = new TRCFMemoryAwareConcurrentHashmap<>(memoryTracker); + this.thresholds = new ConcurrentHashMap<>(); + + this.entityColdStarter = entityColdStarter; + this.featureManager = featureManager; + this.memoryTracker = memoryTracker; + this.initialAcceptFraction = rcfNumMinSamples * 1.0d / rcfNumSamplesInTree; + } + + /** + * Returns to listener the RCF anomaly result using the specified model. 
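From a caller's perspective (purely illustrative: the ids and feature values are invented, error handling is reduced to logging, and an existing ModelManager instance and log4j logger are assumed), the listener-based API is typically consumed with ActionListener.wrap:

    modelManager.getTRcfResult(
        "det-1",                                  // detector id (hypothetical)
        "det-1_model_rcf_0",                      // model id (hypothetical)
        new double[] { 0.2, 12.0 },               // current feature point
        ActionListener.wrap(
            result -> logger.info("grade {}, confidence {}", result.getGrade(), result.getConfidence()),
            exception -> logger.error("RCF scoring failed", exception)
        )
    );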
+ * + * @param detectorId ID of the detector + * @param modelId ID of the model to score the point + * @param point features of the data point + * @param listener onResponse is called with RCF result for the input point, including a score + * onFailure is called with ResourceNotFoundException when the model is not found + * onFailure is called with LimitExceededException when a limit is exceeded for the model + */ + public void getTRcfResult(String detectorId, String modelId, double[] point, ActionListener listener) { + if (forests.containsKey(modelId)) { + getTRcfResult(forests.get(modelId), point, listener); + } else { + checkpointDao + .getTRCFModel( + modelId, + ActionListener + .wrap( + restoredModel -> processRestoredTRcf(restoredModel, modelId, detectorId, point, listener), + listener::onFailure + ) + ); + } + } + + private void getTRcfResult( + ModelState modelState, + double[] point, + ActionListener listener + ) { + modelState.setLastUsedTime(clock.instant()); + + ThresholdedRandomCutForest trcf = modelState.getModel(); + try { + AnomalyDescriptor result = trcf.process(point, 0); + double[] attribution = normalizeAttribution(trcf.getForest(), result.getRelevantAttribution()); + listener + .onResponse( + new ThresholdingResult( + result.getAnomalyGrade(), + result.getDataConfidence(), + result.getRCFScore(), + result.getTotalUpdates(), + result.getRelativeIndex(), + attribution, + result.getPastValues(), + result.getExpectedValuesList(), + result.getLikelihoodOfValues(), + result.getThreshold(), + result.getNumberOfTrees() + ) + ); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * normalize total attribution to 1 + * + * @param forest rcf accessor + * @param rawAttribution raw attribution scores. Can be null when + * 1) the anomaly grade is 0; + * 2) there are missing values and we are using differenced transforms. + * Read RCF's ImputePreprocessor.postProcess. 
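Concretely (made-up numbers), a raw attribution of {3.0, 1.0} normalizes to {0.75, 0.25}, while a null or empty raw attribution falls back to an all-zero array of length dimensions / shingleSize:

    double[] rawAttribution = { 3.0, 1.0 };                    // per-feature contributions
    double sum = java.util.Arrays.stream(rawAttribution).sum();
    double[] normalized = java.util.Arrays.stream(rawAttribution)
        .map(value -> value / sum)                             // now sums to 1.0
        .toArray();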
+ * + * @return normalized attribution + */ + public double[] normalizeAttribution(RandomCutForest forest, double[] rawAttribution) { + if (forest == null) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Empty forest")); + } + // rawAttribution is null when anomaly grade is less than or equals to 0 + // need to create an empty array for bwc because the old node expects an non-empty array + double[] attribution = createEmptyAttribution(forest); + if (rawAttribution != null && rawAttribution.length > 0) { + double sum = Arrays.stream(rawAttribution).sum(); + // avoid dividing by zero error + if (sum > 0) { + if (rawAttribution.length != attribution.length) { + throw new IllegalArgumentException( + String + .format( + Locale.ROOT, + "Unexpected attribution array length: expected %d but is %d", + attribution.length, + rawAttribution.length + ) + ); + } + int numFeatures = rawAttribution.length; + attribution = new double[numFeatures]; + for (int i = 0; i < numFeatures; i++) { + attribution[i] = rawAttribution[i] / sum; + } + } + } + + return attribution; + } + + private double[] createEmptyAttribution(RandomCutForest forest) { + int shingleSize = forest.getShingleSize(); + if (shingleSize <= 0) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "zero shingle size")); + } + int baseDimensions = forest.getDimensions() / shingleSize; + return new double[baseDimensions]; + } + + private Optional> restoreModelState( + Optional rcfModel, + String modelId, + String detectorId + ) { + if (!rcfModel.isPresent()) { + return Optional.empty(); + } + return rcfModel + .filter(rcf -> memoryTracker.isHostingAllowed(detectorId, rcf)) + .map(rcf -> ModelState.createSingleEntityModelState(rcf, modelId, detectorId, ModelType.RCF.getName(), clock)); + } + + private void processRestoredTRcf( + Optional rcfModel, + String modelId, + String detectorId, + double[] point, + ActionListener listener + ) { + Optional> model = restoreModelState(rcfModel, modelId, detectorId); + if (model.isPresent()) { + forests.put(modelId, model.get()); + getTRcfResult(model.get(), point, listener); + } else { + throw new ResourceNotFoundException(detectorId, ADCommonMessages.NO_CHECKPOINT_ERR_MSG + modelId); + } + } + + /** + * Process rcf checkpoint for total rcf updates polling + * @param checkpointModel rcf model restored from its checkpoint + * @param modelId model Id + * @param detectorId detector Id + * @param listener listener to return total updates of rcf + */ + private void processRestoredCheckpoint( + Optional checkpointModel, + String modelId, + String detectorId, + ActionListener listener + ) { + logger.info("Restoring checkpoint for {}", modelId); + Optional> model = restoreModelState(checkpointModel, modelId, detectorId); + if (model.isPresent()) { + forests.put(modelId, model.get()); + if (model.get().getModel() != null && model.get().getModel().getForest() != null) + listener.onResponse(model.get().getModel().getForest().getTotalUpdates()); + } else { + listener.onFailure(new ResourceNotFoundException(detectorId, ADCommonMessages.NO_CHECKPOINT_ERR_MSG + modelId)); + } + } + + /** + * Returns to listener the result using the specified thresholding model. 
+ * + * @param detectorId ID of the detector + * @param modelId ID of the thresholding model + * @param score raw anomaly score + * @param listener onResponse is called with the thresholding model result for the raw score + * onFailure is called with ResourceNotFoundException when the model is not found + */ + public void getThresholdingResult(String detectorId, String modelId, double score, ActionListener listener) { + if (thresholds.containsKey(modelId)) { + getThresholdingResult(thresholds.get(modelId), score, listener); + } else { + checkpointDao + .getThresholdModel( + modelId, + ActionListener + .wrap(model -> processThresholdCheckpoint(model, modelId, detectorId, score, listener), listener::onFailure) + ); + } + } + + private void getThresholdingResult( + ModelState modelState, + double score, + ActionListener listener + ) { + ThresholdingModel threshold = modelState.getModel(); + double grade = threshold.grade(score); + double confidence = threshold.confidence(); + if (score > 0) { + threshold.update(score); + } + modelState.setLastUsedTime(clock.instant()); + listener.onResponse(new ThresholdingResult(grade, confidence, score)); + } + + private void processThresholdCheckpoint( + Optional thresholdModel, + String modelId, + String detectorId, + double score, + ActionListener listener + ) { + Optional> model = thresholdModel + .map( + threshold -> ModelState.createSingleEntityModelState(threshold, modelId, detectorId, ModelType.THRESHOLD.getName(), clock) + ); + if (model.isPresent()) { + thresholds.put(modelId, model.get()); + getThresholdingResult(model.get(), score, listener); + } else { + throw new ResourceNotFoundException(detectorId, ADCommonMessages.NO_CHECKPOINT_ERR_MSG + modelId); + } + } + + /** + * Gets ids of all hosted models. + * + * @return ids of all hosted models. + */ + public Set getAllModelIds() { + return Stream.of(forests.keySet(), thresholds.keySet()).flatMap(set -> set.stream()).collect(Collectors.toSet()); + } + + /** + * Gets modelStates of all model partitions hosted on a node + * + * @return list of modelStates + */ + public List> getAllModels() { + return Stream.concat(forests.values().stream(), thresholds.values().stream()).collect(Collectors.toList()); + } + + /** + * Stops hosting the model and creates a checkpoint. + * + * Used when adding a OpenSearch node. We have to stop all models because + * requests for those model ids would be sent to other nodes. If we don't stop + * them, there would be memory leak. 
+ * + * @param detectorId ID of the detector + * @param modelId ID of the model to stop hosting + * @param listener onResponse is called with null when the operation is completed + */ + public void stopModel(String detectorId, String modelId, ActionListener listener) { + logger.info(String.format(Locale.ROOT, "Stopping detector %s model %s", detectorId, modelId)); + stopModel(forests, modelId, ActionListener.wrap(r -> stopModel(thresholds, modelId, listener), listener::onFailure)); + } + + private void stopModel(Map> models, String modelId, ActionListener listener) { + Instant now = clock.instant(); + Optional> modelState = Optional + .ofNullable(models.remove(modelId)) + .filter(model -> model.getLastCheckpointTime().plus(checkpointInterval).isBefore(now)); + if (modelState.isPresent()) { + T model = modelState.get().getModel(); + if (model instanceof ThresholdedRandomCutForest) { + checkpointDao + .putTRCFCheckpoint( + modelId, + (ThresholdedRandomCutForest) model, + ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) + ); + } else if (model instanceof ThresholdingModel) { + checkpointDao + .putThresholdCheckpoint( + modelId, + (ThresholdingModel) model, + ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) + ); + } else { + listener.onFailure(new IllegalArgumentException("Unexpected model type")); + } + } else { + listener.onResponse(null); + } + } + + /** + * Permanently deletes models hosted in memory and persisted in index. + * When stop realtime job, will call this method to clear all model cache + * and checkpoints. + * + * @param detectorId id the of the detector for which models are to be permanently deleted + * @param listener onResponse is called with null when this operation is completed + */ + public void clear(String detectorId, ActionListener listener) { + clearModels(detectorId, forests, ActionListener.wrap(r -> clearModels(detectorId, thresholds, listener), listener::onFailure)); + } + + private void clearModels(String detectorId, Map models, ActionListener listener) { + Iterator id = models.keySet().iterator(); + clearModelForIterator(detectorId, models, id, listener); + } + + private void clearModelForIterator(String detectorId, Map models, Iterator idIter, ActionListener listener) { + if (idIter.hasNext()) { + String modelId = idIter.next(); + if (SingleStreamModelIdMapper.getDetectorIdForModelId(modelId).equals(detectorId)) { + models.remove(modelId); + checkpointDao + .deleteModelCheckpoint( + modelId, + ActionListener.wrap(r -> clearModelForIterator(detectorId, models, idIter, listener), listener::onFailure) + ); + } else { + clearModelForIterator(detectorId, models, idIter, listener); + } + } else { + listener.onResponse(null); + } + } + + /** + * Trains and saves cold-start AD models. + * + * This implementations splits RCF models and trains them all. + * As all model partitions have the same size, the scores from RCF models are merged by averaging. + * Since RCF outputs 0 until it is ready, initial 0 scores are meaningless and therefore filtered out. + * Filtered (non-zero) RCF scores are the training data for a single thresholding model. + * All trained models are serialized and persisted to be hosted. 
+ * + * @param anomalyDetector the detector for which models are trained + * @param dataPoints M, N shape, where M is the number of samples for training and N is the number of features + * @param listener onResponse is called with null when this operation is completed + * onFailure is called IllegalArgumentException when training data is invalid + * onFailure is called LimitExceededException when a limit for training is exceeded + */ + public void trainModel(AnomalyDetector anomalyDetector, double[][] dataPoints, ActionListener listener) { + if (dataPoints.length == 0 || dataPoints[0].length == 0) { + listener.onFailure(new IllegalArgumentException("Data points must not be empty.")); + } else { + int rcfNumFeatures = dataPoints[0].length; + try { + trainModelForStep(anomalyDetector, dataPoints, rcfNumFeatures, 0, listener); + } catch (Exception e) { + listener.onFailure(e); + } + } + } + + private void trainModelForStep( + AnomalyDetector detector, + double[][] dataPoints, + int rcfNumFeatures, + int step, + ActionListener listener + ) { + ThresholdedRandomCutForest trcf = ThresholdedRandomCutForest + .builder() + .dimensions(rcfNumFeatures) + .sampleSize(rcfNumSamplesInTree) + .numberOfTrees(rcfNumTrees) + .timeDecay(rcfTimeDecay) + .outputAfter(rcfNumMinSamples) + .initialAcceptFraction(initialAcceptFraction) + .parallelExecutionEnabled(false) + .compact(true) + .precision(Precision.FLOAT_32) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + .shingleSize(detector.getShingleSize()) + .anomalyRate(1 - thresholdMinPvalue) + .build(); + Arrays.stream(dataPoints).forEach(s -> trcf.process(s, 0)); + + String modelId = SingleStreamModelIdMapper.getRcfModelId(detector.getId(), step); + checkpointDao.putTRCFCheckpoint(modelId, trcf, ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure)); + } + + /** + * Does model maintenance. + * + * The implementation makes checkpoints for hosted models and stops hosting models not recently used. + * + * @param listener onResponse is called with null when this operation is completed. 
+ */ + public void maintenance(ActionListener listener) { + maintenanceForIterator( + forests, + forests.entrySet().iterator(), + ActionListener.wrap(r -> maintenanceForIterator(thresholds, thresholds.entrySet().iterator(), listener), listener::onFailure) + ); + } + + private void maintenanceForIterator( + Map> models, + Iterator>> iter, + ActionListener listener + ) { + if (iter.hasNext()) { + Entry> modelEntry = iter.next(); + String modelId = modelEntry.getKey(); + ModelState modelState = modelEntry.getValue(); + Instant now = clock.instant(); + if (modelState.expired(modelTtl)) { + models.remove(modelId); + } + if (modelState.getLastCheckpointTime().plus(checkpointInterval).isBefore(now)) { + ActionListener checkpointListener = ActionListener.wrap(r -> { + modelState.setLastCheckpointTime(now); + maintenanceForIterator(models, iter, listener); + }, e -> { + logger.warn("Failed to finish maintenance for model id " + modelId, e); + maintenanceForIterator(models, iter, listener); + }); + T model = modelState.getModel(); + if (model instanceof ThresholdedRandomCutForest) { + checkpointDao.putTRCFCheckpoint(modelId, (ThresholdedRandomCutForest) model, checkpointListener); + } else if (model instanceof ThresholdingModel) { + checkpointDao.putThresholdCheckpoint(modelId, (ThresholdingModel) model, checkpointListener); + } else { + checkpointListener.onFailure(new IllegalArgumentException("Unexpected model type")); + } + } else { + maintenanceForIterator(models, iter, listener); + } + } else { + listener.onResponse(null); + } + } + + /** + * Returns computed anomaly results for preview data points. + * + * @param dataPoints features of preview data points + * @param shingleSize model shingle size + * @return thresholding results of preview data points + * @throws IllegalArgumentException when preview data points are not valid + */ + public List getPreviewResults(double[][] dataPoints, int shingleSize) { + if (dataPoints.length < minPreviewSize) { + throw new IllegalArgumentException("Insufficient data for preview results. Minimum required: " + minPreviewSize); + } + // Train RCF models and collect non-zero scores + int rcfNumFeatures = dataPoints[0].length; + // speed is important in preview. We don't want cx to wait too long. + // thus use the default value of boundingBoxCacheFraction = 1 + ThresholdedRandomCutForest trcf = ThresholdedRandomCutForest + .builder() + .randomSeed(0L) + .dimensions(rcfNumFeatures) + .sampleSize(rcfNumSamplesInTree) + .numberOfTrees(rcfNumTrees) + .timeDecay(rcfTimeDecay) + .outputAfter(rcfNumMinSamples) + .initialAcceptFraction(initialAcceptFraction) + .parallelExecutionEnabled(false) + .compact(true) + .precision(Precision.FLOAT_32) + .boundingBoxCacheFraction(AnomalyDetectorSettings.BATCH_BOUNDING_BOX_CACHE_RATIO) + .shingleSize(shingleSize) + .anomalyRate(1 - this.thresholdMinPvalue) + .build(); + return Arrays.stream(dataPoints).map(point -> { + AnomalyDescriptor descriptor = trcf.process(point, 0); + return new ThresholdingResult( + descriptor.getAnomalyGrade(), + descriptor.getDataConfidence(), + descriptor.getRCFScore(), + descriptor.getTotalUpdates(), + descriptor.getRelativeIndex(), + normalizeAttribution(trcf.getForest(), descriptor.getRelevantAttribution()), + descriptor.getPastValues(), + descriptor.getExpectedValuesList(), + descriptor.getLikelihoodOfValues(), + descriptor.getThreshold(), + rcfNumTrees + ); + }).collect(Collectors.toList()); + } + + /** + * Get all RCF partition's size corresponding to a detector. 
Thresholding models' size is a constant since they are small in size (KB). + * @param detectorId detector id + * @return a map of model id to its memory size + */ + @Override + public Map getModelSize(String detectorId) { + Map res = new HashMap<>(); + forests + .entrySet() + .stream() + .filter(entry -> SingleStreamModelIdMapper.getDetectorIdForModelId(entry.getKey()).equals(detectorId)) + .forEach(entry -> { res.put(entry.getKey(), memoryTracker.estimateTRCFModelSize(entry.getValue().getModel())); }); + thresholds + .entrySet() + .stream() + .filter(entry -> SingleStreamModelIdMapper.getDetectorIdForModelId(entry.getKey()).equals(detectorId)) + .forEach(entry -> { res.put(entry.getKey(), (long) memoryTracker.getThresholdModelBytes()); }); + return res; + } + + /** + * Get a RCF model's total updates. + * @param modelId the RCF model's id + * @param detectorId detector Id + * @param listener listener to return the result + */ + public void getTotalUpdates(String modelId, String detectorId, ActionListener listener) { + ModelState model = forests.get(modelId); + if (model != null) { + if (model.getModel() != null && model.getModel().getForest() != null) { + listener.onResponse(model.getModel().getForest().getTotalUpdates()); + } else { + listener.onResponse(0L); + } + } else { + checkpointDao + .getTRCFModel( + modelId, + ActionListener + .wrap(checkpoint -> processRestoredCheckpoint(checkpoint, modelId, detectorId, listener), listener::onFailure) + ); + } + } + + /** + * Compute anomaly result for the given data point + * @param datapoint Data point + * @param modelState the state associated with the entity + * @param modelId the model Id + * @param entity entity accessor + * @param shingleSize Shingle size + * + * @return anomaly result, confidence, and the corresponding RCF score. + */ + public ThresholdingResult getAnomalyResultForEntity( + double[] datapoint, + ModelState modelState, + String modelId, + Entity entity, + int shingleSize + ) { + ThresholdingResult result = new ThresholdingResult(0, 0, 0); + if (modelState != null) { + EntityModel entityModel = modelState.getModel(); + + if (entityModel == null) { + entityModel = new EntityModel(entity, new ArrayDeque<>(), null); + modelState.setModel(entityModel); + } + + if (!entityModel.getTrcf().isPresent()) { + entityColdStarter.trainModelFromExistingSamples(modelState, shingleSize); + } + + if (entityModel.getTrcf().isPresent()) { + result = score(datapoint, modelId, modelState); + } else { + entityModel.addSample(datapoint); + } + } + return result; + } + + public ThresholdingResult score(double[] feature, String modelId, ModelState modelState) { + ThresholdingResult result = new ThresholdingResult(0, 0, 0); + EntityModel model = modelState.getModel(); + try { + if (model != null && model.getTrcf().isPresent()) { + ThresholdedRandomCutForest trcf = model.getTrcf().get(); + Optional.ofNullable(model.getSamples()).ifPresent(q -> { + q.stream().forEach(s -> trcf.process(s, 0)); + q.clear(); + }); + result = toResult(trcf.getForest(), trcf.process(feature, 0)); + } + } catch (Exception e) { + logger + .error( + new ParameterizedMessage( + "Fail to score for [{}]: model Id [{}], feature [{}]", + modelState.getModel().getEntity(), + modelId, + Arrays.toString(feature) + ), + e + ); + throw e; + } finally { + modelState.setLastUsedTime(clock.instant()); + } + return result; + } + + /** + * Instantiate an entity state out of checkpoint. Train models if there are + * enough samples. 
+ * @param checkpoint Checkpoint loaded from index + * @param entity objects to access Entity attributes + * @param modelId Model Id + * @param detectorId Detector Id + * @param shingleSize Shingle size + * + * @return updated model state + * + */ + public ModelState processEntityCheckpoint( + Optional> checkpoint, + Entity entity, + String modelId, + String detectorId, + int shingleSize + ) { + // entity state to instantiate + ModelState modelState = new ModelState<>( + new EntityModel(entity, new ArrayDeque<>(), null), + modelId, + detectorId, + ModelType.ENTITY.getName(), + clock, + 0 + ); + + if (checkpoint.isPresent()) { + Entry modelToTime = checkpoint.get(); + EntityModel restoredModel = modelToTime.getKey(); + combineSamples(modelState.getModel(), restoredModel); + modelState.setModel(restoredModel); + modelState.setLastCheckpointTime(modelToTime.getValue()); + } + EntityModel model = modelState.getModel(); + if (model == null) { + model = new EntityModel(null, new ArrayDeque<>(), null); + modelState.setModel(model); + } + + if (!model.getTrcf().isPresent() && model.getSamples() != null && model.getSamples().size() >= rcfNumMinSamples) { + entityColdStarter.trainModelFromExistingSamples(modelState, shingleSize); + } + return modelState; + } + + private void combineSamples(EntityModel fromModel, EntityModel toModel) { + Queue samples = fromModel.getSamples(); + while (samples.peek() != null) { + toModel.addSample(samples.poll()); + } + } + + private ThresholdingResult toResult(RandomCutForest rcf, AnomalyDescriptor anomalyDescriptor) { + return new ThresholdingResult( + anomalyDescriptor.getAnomalyGrade(), + anomalyDescriptor.getDataConfidence(), + anomalyDescriptor.getRCFScore(), + anomalyDescriptor.getTotalUpdates(), + anomalyDescriptor.getRelativeIndex(), + normalizeAttribution(rcf, anomalyDescriptor.getRelevantAttribution()), + anomalyDescriptor.getPastValues(), + anomalyDescriptor.getExpectedValuesList(), + anomalyDescriptor.getLikelihoodOfValues(), + anomalyDescriptor.getThreshold(), + rcfNumTrees + ); + } +} diff --git a/src/main/java/org/opensearch/ad/ml/ModelState.java-e b/src/main/java/org/opensearch/ad/ml/ModelState.java-e new file mode 100644 index 000000000..9e909bc58 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ml/ModelState.java-e @@ -0,0 +1,212 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.ad.ExpiringState; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.timeseries.constant.CommonName; + +/** + * A ML model and states such as usage. 
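As a usage sketch (the ids are invented; the factory call mirrors how ModelManager.restoreModelState wires a restored forest into a state object, and the tiny forest built here is only for illustration):

    // assumes imports of java.time.Clock and
    // com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest
    ThresholdedRandomCutForest trcf = ThresholdedRandomCutForest
        .builder()
        .dimensions(1)
        .sampleSize(256)
        .build();
    ModelState<ThresholdedRandomCutForest> state = ModelState.createSingleEntityModelState(
        trcf,
        "det-1_model_rcf_0",                  // model id (hypothetical)
        "det-1",                              // detector id (hypothetical)
        ModelManager.ModelType.RCF.getName(), // "rcf"
        Clock.systemUTC()
    );
    state.setLastUsedTime(Clock.systemUTC().instant());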
+ */ +public class ModelState implements ExpiringState { + + public static String MODEL_TYPE_KEY = "model_type"; + public static String LAST_USED_TIME_KEY = "last_used_time"; + public static String LAST_CHECKPOINT_TIME_KEY = "last_checkpoint_time"; + public static String PRIORITY_KEY = "priority"; + private T model; + private String modelId; + private String detectorId; + private String modelType; + // time when the ML model was used last time + private Instant lastUsedTime; + private Instant lastCheckpointTime; + private Clock clock; + private float priority; + + /** + * Constructor. + * + * @param model ML model + * @param modelId Id of model partition + * @param detectorId Id of detector this model partition is used for + * @param modelType type of model + * @param clock UTC clock + * @param priority Priority of the model state. Used in multi-entity detectors' cache. + */ + public ModelState(T model, String modelId, String detectorId, String modelType, Clock clock, float priority) { + this.model = model; + this.modelId = modelId; + this.detectorId = detectorId; + this.modelType = modelType; + this.lastUsedTime = clock.instant(); + // this is inaccurate until we find the last checkpoint time from disk + this.lastCheckpointTime = Instant.MIN; + this.clock = clock; + this.priority = priority; + } + + /** + * Create state with zero priority. Used in single-entity detector. + * + * @param Model object's type + * @param model The actual model object + * @param modelId Model Id + * @param detectorId Detector Id + * @param modelType Model type like RCF model + * @param clock UTC clock + * + * @return the created model state + */ + public static ModelState createSingleEntityModelState( + T model, + String modelId, + String detectorId, + String modelType, + Clock clock + ) { + return new ModelState<>(model, modelId, detectorId, modelType, clock, 0f); + } + + /** + * Returns the ML model. + * + * @return the ML model. + */ + public T getModel() { + return this.model; + } + + public void setModel(T model) { + this.model = model; + } + + /** + * Gets the model ID + * + * @return modelId of model + */ + public String getModelId() { + return modelId; + } + + /** + * Gets the detectorID of the model + * + * @return detectorId associated with the model + */ + public String getId() { + return detectorId; + } + + /** + * Gets the type of the model + * + * @return modelType of the model + */ + public String getModelType() { + return modelType; + } + + /** + * Returns the time when the ML model was last used. + * + * @return the time when the ML model was last used + */ + public Instant getLastUsedTime() { + return this.lastUsedTime; + } + + /** + * Sets the time when ML model was last used. + * + * @param lastUsedTime time when the ML model was used last time + */ + public void setLastUsedTime(Instant lastUsedTime) { + this.lastUsedTime = lastUsedTime; + } + + /** + * Returns the time when a checkpoint for the ML model was made last time. + * + * @return the time when a checkpoint for the ML model was made last time. + */ + public Instant getLastCheckpointTime() { + return this.lastCheckpointTime; + } + + /** + * Sets the time when a checkpoint for the ML model was made last time. + * + * @param lastCheckpointTime time when a checkpoint for the ML model was made last time. 
+ */ + public void setLastCheckpointTime(Instant lastCheckpointTime) { + this.lastCheckpointTime = lastCheckpointTime; + } + + /** + * Returns priority of the ModelState + * @return the priority + */ + public float getPriority() { + return priority; + } + + public void setPriority(float priority) { + this.priority = priority; + } + + /** + * Gets the Model State as a map + * + * @return Map of ModelStates + */ + public Map getModelStateAsMap() { + return new HashMap() { + { + put(CommonName.MODEL_ID_FIELD, modelId); + put(ADCommonName.DETECTOR_ID_KEY, detectorId); + put(MODEL_TYPE_KEY, modelType); + /* A stats API broadcasts requests to all nodes and renders node responses using toXContent. + * + * For the local node, the stats API's calls toXContent on the node response directly. + * For remote node, the coordinating node gets a serialized content from + * ADStatsNodeResponse.writeTo, deserializes the content, and renders the result using toXContent. + * Since ADStatsNodeResponse.writeTo uses StreamOutput::writeGenericValue, we can only use + * a long instead of the Instant object itself as + * StreamOutput::writeGenericValue only recognizes built-in types.*/ + put(LAST_USED_TIME_KEY, lastUsedTime.toEpochMilli()); + if (lastCheckpointTime != Instant.MIN) { + put(LAST_CHECKPOINT_TIME_KEY, lastCheckpointTime.toEpochMilli()); + } + if (model != null && model instanceof EntityModel) { + EntityModel summary = (EntityModel) model; + if (summary.getEntity().isPresent()) { + put(CommonName.ENTITY_KEY, summary.getEntity().get().toStat()); + } + } + } + }; + } + + @Override + public boolean expired(Duration stateTtl) { + return expired(lastUsedTime, stateTtl, clock.instant()); + } +} diff --git a/src/main/java/org/opensearch/ad/ml/SingleStreamModelIdMapper.java-e b/src/main/java/org/opensearch/ad/ml/SingleStreamModelIdMapper.java-e new file mode 100644 index 000000000..ac3ce899d --- /dev/null +++ b/src/main/java/org/opensearch/ad/ml/SingleStreamModelIdMapper.java-e @@ -0,0 +1,76 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import java.util.Locale; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Utilities to map between single-stream models and ids. We will have circular + * dependency between ModelManager and CheckpointDao if we put these functions inside + * ModelManager. + * + */ +public class SingleStreamModelIdMapper { + protected static final String DETECTOR_ID_PATTERN = "(.*)_model_.+"; + protected static final String RCF_MODEL_ID_PATTERN = "%s_model_rcf_%d"; + protected static final String THRESHOLD_MODEL_ID_PATTERN = "%s_model_threshold"; + + /** + * Returns the model ID for the RCF model partition. + * + * @param detectorId ID of the detector for which the RCF model is trained + * @param partitionNumber number of the partition + * @return ID for the RCF model partition + */ + public static String getRcfModelId(String detectorId, int partitionNumber) { + return String.format(Locale.ROOT, RCF_MODEL_ID_PATTERN, detectorId, partitionNumber); + } + + /** + * Returns the model ID for the thresholding model. 
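For a concrete picture of the id scheme (the detector id is invented; the threshold-model method is shown just below), the mapper produces and parses ids like this:

    String detectorId = "det-1";  // hypothetical detector id
    String rcfModelId = SingleStreamModelIdMapper.getRcfModelId(detectorId, 0);
    // -> "det-1_model_rcf_0"
    String thresholdModelId = SingleStreamModelIdMapper.getThresholdModelId(detectorId);
    // -> "det-1_model_threshold"
    String parsedDetectorId = SingleStreamModelIdMapper.getDetectorIdForModelId(rcfModelId);
    // -> "det-1", recovered via the "(.*)_model_.+" pattern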
+ * + * @param detectorId ID of the detector for which the thresholding model is trained + * @return ID for the thresholding model + */ + public static String getThresholdModelId(String detectorId) { + return String.format(Locale.ROOT, THRESHOLD_MODEL_ID_PATTERN, detectorId); + } + + /** + * Gets the detector id from the model id. + * + * @param modelId id of a model + * @return id of the detector the model is for + * @throws IllegalArgumentException if model id is invalid + */ + public static String getDetectorIdForModelId(String modelId) { + Matcher matcher = Pattern.compile(DETECTOR_ID_PATTERN).matcher(modelId); + if (matcher.matches()) { + return matcher.group(1); + } else { + throw new IllegalArgumentException("Invalid model id " + modelId); + } + } + + /** + * Returns the model ID for the thresholding model according to the input + * rcf model id. + * @param rcfModelId RCF model id + * @return thresholding model Id + */ + public static String getThresholdModelIdFromRCFModelId(String rcfModelId) { + String detectorId = getDetectorIdForModelId(rcfModelId); + return getThresholdModelId(detectorId); + } +} diff --git a/src/main/java/org/opensearch/ad/ml/TRCFMemoryAwareConcurrentHashmap.java-e b/src/main/java/org/opensearch/ad/ml/TRCFMemoryAwareConcurrentHashmap.java-e new file mode 100644 index 000000000..7b7b1fe7d --- /dev/null +++ b/src/main/java/org/opensearch/ad/ml/TRCFMemoryAwareConcurrentHashmap.java-e @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import java.util.concurrent.ConcurrentHashMap; + +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.MemoryTracker.Origin; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * A customized ConcurrentHashMap that can automatically consume and release memory. + * This enables minimum change to our single-entity code as we just have to replace + * the map implementation. + * + * Note: this is mainly used for single-entity detectors. 
+ */ +public class TRCFMemoryAwareConcurrentHashmap extends ConcurrentHashMap> { + private final MemoryTracker memoryTracker; + + public TRCFMemoryAwareConcurrentHashmap(MemoryTracker memoryTracker) { + this.memoryTracker = memoryTracker; + } + + @Override + public ModelState remove(Object key) { + ModelState deletedModelState = super.remove(key); + if (deletedModelState != null && deletedModelState.getModel() != null) { + long memoryToRelease = memoryTracker.estimateTRCFModelSize(deletedModelState.getModel()); + memoryTracker.releaseMemory(memoryToRelease, true, Origin.SINGLE_ENTITY_DETECTOR); + } + return deletedModelState; + } + + @Override + public ModelState put(K key, ModelState value) { + ModelState previousAssociatedState = super.put(key, value); + if (value != null && value.getModel() != null) { + long memoryToConsume = memoryTracker.estimateTRCFModelSize(value.getModel()); + memoryTracker.consumeMemory(memoryToConsume, true, Origin.SINGLE_ENTITY_DETECTOR); + } + return previousAssociatedState; + } +} diff --git a/src/main/java/org/opensearch/ad/ml/ThresholdingModel.java-e b/src/main/java/org/opensearch/ad/ml/ThresholdingModel.java-e new file mode 100644 index 000000000..6b9a488de --- /dev/null +++ b/src/main/java/org/opensearch/ad/ml/ThresholdingModel.java-e @@ -0,0 +1,68 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import java.util.List; + +/** + * A model for converting raw anomaly scores into anomaly grades. + * + * A thresholding model is trained on a set of raw anomaly scores like those + * output from the Random Cut Forest algorithm. The fundamental assumption of + * anomaly scores is that the larger the score the more anomalous the + * corresponding data point. Based on this training set an internal threshold is + * computed to determine if a given score is anomalous. The thresholding model + * can be updated with new anomaly scores such as in a streaming context. + * + */ +public interface ThresholdingModel { + + /** + * Initializes the model using a training set of anomaly scores. + * + * @param anomalyScores array of anomaly scores with which to train the model + */ + void train(double[] anomalyScores); + + /** + * Update the model with a new anomaly score. + * + * @param anomalyScore an anomaly score + */ + void update(double anomalyScore); + + /** + * Computes the anomaly grade associated with the given anomaly score. A + * non-zero grade implies that the given score is anomalous. The magnitude + * of the grade, a value between 0 and 1, indicates the severity of the + * anomaly. + * + * @param anomalyScore an anomaly score + * @return the associated anomaly grade + */ + double grade(double anomalyScore); + + /** + * Returns the confidence of the model in predicting anomaly grades; that + * is, the probability that the reported anomaly grade is correct according + * to the underlying model. 
+ * + * @return the model confidence + */ + double confidence(); + + /** + * Extract scores + * @return the extract scores + */ + List extractScores(); +} diff --git a/src/main/java/org/opensearch/ad/ml/ThresholdingResult.java-e b/src/main/java/org/opensearch/ad/ml/ThresholdingResult.java-e new file mode 100644 index 000000000..a2da03f51 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ml/ThresholdingResult.java-e @@ -0,0 +1,337 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.FeatureData; + +/** + * Data object containing thresholding results. + */ +public class ThresholdingResult extends IntermediateResult { + + private final double grade; + /** + * position of the anomaly vis a vis the current time (can be -ve) if anomaly is + * detected late, which can and should happen sometime; for shingle size 1; this + * is always 0. + * + * For example, current shingle is + [ + 6819.0, + 2375.3333333333335, + 0.0, + 49882.0, + 92070.0, + 5084.0, + 2072.809523809524, + 0.0, + 43529.0, + 91169.0, + 8129.0, + 2582.892857142857, + 12.0, + 54241.0, + 84596.0, + 11174.0, + 3092.9761904761904, + 24.0, + 64952.0, + 78024.0, + 14220.0, + 3603.059523809524, + 37.0, + 75664.0, + 71451.0, + 17265.0, + 4113.142857142857, + 49.0, + 86376.0, + 64878.0, + 16478.0, + 3761.4166666666665, + 37.0, + 78990.0, + 70057.0, + 15691.0, + 3409.690476190476, + 24.0, + 71604.0, + 75236.0 + ], + * If rcf returns relativeIndex is -2, baseDimension is 5, we look back baseDimension * 2 and get the + * culprit input that triggers anomaly: + [17265.0, + 4113.142857142857, + 49.0, + 86376.0, + 64878.0 + ], + */ + private int relativeIndex; + + // a flattened version denoting the basic contribution of each input variable + private double[] relevantAttribution; + + // pastValues is related to relativeIndex and startOfAnomaly. Read the same + // field comment on AnomalyResult. + private double[] pastValues; + + /* + * The expected value is only calculated for anomalous detection intervals, + * and will generate expected value for each feature if detector has multiple + * features. + * Currently we expect one set of expected values. In the future, we + * might give different expected values with differently likelihood. So + * the two-dimensional array allows us to future-proof our applications. + * Also, expected values correspond to pastValues if present or current input + * point otherwise. If pastValues is present, we can add a text on UX to explain + * we found an anomaly from the past. + Example: + "expected_value": [{ + "likelihood": 0.8, + "value_list": [{ + "feature_id": "blah", + "value": 1 + }, + { + "feature_id": "blah2", + "value": 1 + } + ] + }]*/ + private double[][] expectedValuesList; + + // likelihood values for the list. 
+ // There will be one likelihood value that spans a single set of expected values. + // For now, only one likelihood value should be expected as there is only + // one set of expected values. + private double[] likelihoodOfValues; + + // rcf score threshold at the time of writing a result + private double threshold; + + // size of the forest + private int forestSize; + + protected final double confidence; + + /** + * Constructor for default empty value or backward compatibility. + * In terms of bwc, when an old node sends request for threshold results, + * we need to return only what they understand. + * + * @param grade anomaly grade + * @param confidence confidence for the grade + * @param rcfScore rcf score associated with the grade and confidence. Used + * by multi-entity detector to differentiate whether the result is worth + * saving or not. + */ + public ThresholdingResult(double grade, double confidence, double rcfScore) { + this(grade, confidence, rcfScore, 0, 0, null, null, null, null, 0, 0); + } + + public ThresholdingResult( + double grade, + double confidence, + double rcfScore, + long totalUpdates, + int relativeIndex, + double[] relevantAttribution, + double[] pastValues, + double[][] expectedValuesList, + double[] likelihoodOfValues, + double threshold, + int forestSize + ) { + super(totalUpdates, rcfScore); + this.confidence = confidence; + this.grade = grade; + + this.relativeIndex = relativeIndex; + this.relevantAttribution = relevantAttribution; + this.pastValues = pastValues; + this.expectedValuesList = expectedValuesList; + this.likelihoodOfValues = likelihoodOfValues; + this.threshold = threshold; + this.forestSize = forestSize; + } + + /** + * Returns the confidence for the result (e.g., anomaly grade in AD). + * + * @return confidence for the result + */ + public double getConfidence() { + return confidence; + } + + /** + * Returns the anomaly grade. 
+ * + * @return the anoamly grade + */ + public double getGrade() { + return grade; + } + + public int getRelativeIndex() { + return relativeIndex; + } + + public double[] getRelevantAttribution() { + return relevantAttribution; + } + + public double[] getPastValues() { + return pastValues; + } + + public double[][] getExpectedValuesList() { + return expectedValuesList; + } + + public double[] getLikelihoodOfValues() { + return likelihoodOfValues; + } + + public double getThreshold() { + return threshold; + } + + public int getForestSize() { + return forestSize; + } + + @Override + public boolean equals(Object o) { + if (!super.equals(o)) + return false; + if (getClass() != o.getClass()) + return false; + ThresholdingResult that = (ThresholdingResult) o; + return Double.doubleToLongBits(confidence) == Double.doubleToLongBits(that.confidence) + && Double.doubleToLongBits(this.grade) == Double.doubleToLongBits(that.grade) + && this.relativeIndex == that.relativeIndex + && Arrays.equals(relevantAttribution, that.relevantAttribution) + && Arrays.equals(pastValues, that.pastValues) + && Arrays.deepEquals(expectedValuesList, that.expectedValuesList) + && Arrays.equals(likelihoodOfValues, that.likelihoodOfValues) + && Double.doubleToLongBits(threshold) == Double.doubleToLongBits(that.threshold) + && forestSize == that.forestSize; + } + + @Override + public int hashCode() { + return Objects + .hash( + super.hashCode(), + confidence, + grade, + relativeIndex, + Arrays.hashCode(relevantAttribution), + Arrays.hashCode(pastValues), + Arrays.deepHashCode(expectedValuesList), + Arrays.hashCode(likelihoodOfValues), + threshold, + forestSize + ); + } + + @Override + public String toString() { + return new ToStringBuilder(this) + .append(super.toString()) + .append("grade", grade) + .append("confidence", confidence) + .append("relativeIndex", relativeIndex) + .append("relevantAttribution", Arrays.toString(relevantAttribution)) + .append("pastValues", Arrays.toString(pastValues)) + .append("expectedValuesList", Arrays.deepToString(expectedValuesList)) + .append("likelihoodOfValues", Arrays.toString(likelihoodOfValues)) + .append("threshold", threshold) + .append("forestSize", forestSize) + .toString(); + } + + /** + * + * Convert ThresholdingResult to AnomalyResult + * + * @param detector Detector config + * @param dataStartInstant data start time + * @param dataEndInstant data end time + * @param executionStartInstant execution start time + * @param executionEndInstant execution end time + * @param featureData Feature data list + * @param entity Entity attributes + * @param schemaVersion Schema version + * @param modelId Model Id + * @param taskId Task Id + * @param error Error + * @return converted AnomalyResult + */ + @Override + public List toIndexableResults( + Config detector, + Instant dataStartInstant, + Instant dataEndInstant, + Instant executionStartInstant, + Instant executionEndInstant, + List featureData, + Optional entity, + Integer schemaVersion, + String modelId, + String taskId, + String error + ) { + return Collections + .singletonList( + AnomalyResult + .fromRawTRCFResult( + detector.getId(), + detector.getIntervalInMilliseconds(), + taskId, + rcfScore, + grade, + confidence, + featureData, + dataStartInstant, + dataEndInstant, + executionStartInstant, + executionEndInstant, + error, + entity, + detector.getUser(), + schemaVersion, + modelId, + relevantAttribution, + relativeIndex, + pastValues, + expectedValuesList, + likelihoodOfValues, + threshold + ) + ); + } +} diff --git 
a/src/main/java/org/opensearch/ad/model/ADEntityTaskProfile.java b/src/main/java/org/opensearch/ad/model/ADEntityTaskProfile.java index 90eccb019..3d473d0e2 100644 --- a/src/main/java/org/opensearch/ad/model/ADEntityTaskProfile.java +++ b/src/main/java/org/opensearch/ad/model/ADEntityTaskProfile.java @@ -11,14 +11,14 @@ package org.opensearch.ad.model; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.util.Objects; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; diff --git a/src/main/java/org/opensearch/ad/model/ADEntityTaskProfile.java-e b/src/main/java/org/opensearch/ad/model/ADEntityTaskProfile.java-e new file mode 100644 index 000000000..3d473d0e2 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/ADEntityTaskProfile.java-e @@ -0,0 +1,308 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Objects; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.model.Entity; + +/** + * HC detector's entity task profile. 
+ */ +public class ADEntityTaskProfile implements ToXContentObject, Writeable { + + public static final String SHINGLE_SIZE_FIELD = "shingle_size"; + public static final String RCF_TOTAL_UPDATES_FIELD = "rcf_total_updates"; + public static final String THRESHOLD_MODEL_TRAINED_FIELD = "threshold_model_trained"; + public static final String THRESHOLD_MODEL_TRAINING_DATA_SIZE_FIELD = "threshold_model_training_data_size"; + public static final String MODEL_SIZE_IN_BYTES = "model_size_in_bytes"; + public static final String NODE_ID_FIELD = "node_id"; + public static final String ENTITY_FIELD = "entity"; + public static final String TASK_ID_FIELD = "task_id"; + public static final String AD_TASK_TYPE_FIELD = "task_type"; + + private Integer shingleSize; + private Long rcfTotalUpdates; + private Boolean thresholdModelTrained; + private Integer thresholdModelTrainingDataSize; + private Long modelSizeInBytes; + private String nodeId; + private Entity entity; + private String taskId; + private String adTaskType; + + public ADEntityTaskProfile( + Integer shingleSize, + Long rcfTotalUpdates, + Boolean thresholdModelTrained, + Integer thresholdModelTrainingDataSize, + Long modelSizeInBytes, + String nodeId, + Entity entity, + String taskId, + String adTaskType + ) { + this.shingleSize = shingleSize; + this.rcfTotalUpdates = rcfTotalUpdates; + this.thresholdModelTrained = thresholdModelTrained; + this.thresholdModelTrainingDataSize = thresholdModelTrainingDataSize; + this.modelSizeInBytes = modelSizeInBytes; + this.nodeId = nodeId; + this.entity = entity; + this.taskId = taskId; + this.adTaskType = adTaskType; + } + + public static ADEntityTaskProfile parse(XContentParser parser) throws IOException { + Integer shingleSize = null; + Long rcfTotalUpdates = null; + Boolean thresholdModelTrained = null; + Integer thresholdModelTrainingDataSize = null; + Long modelSizeInBytes = null; + String nodeId = null; + Entity entity = null; + String taskId = null; + String taskType = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case SHINGLE_SIZE_FIELD: + shingleSize = parser.intValue(); + break; + case RCF_TOTAL_UPDATES_FIELD: + rcfTotalUpdates = parser.longValue(); + break; + case THRESHOLD_MODEL_TRAINED_FIELD: + thresholdModelTrained = parser.booleanValue(); + break; + case THRESHOLD_MODEL_TRAINING_DATA_SIZE_FIELD: + thresholdModelTrainingDataSize = parser.intValue(); + break; + case MODEL_SIZE_IN_BYTES: + modelSizeInBytes = parser.longValue(); + break; + case NODE_ID_FIELD: + nodeId = parser.text(); + break; + case ENTITY_FIELD: + entity = Entity.parse(parser); + break; + case TASK_ID_FIELD: + taskId = parser.text(); + break; + case AD_TASK_TYPE_FIELD: + taskType = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + return new ADEntityTaskProfile( + shingleSize, + rcfTotalUpdates, + thresholdModelTrained, + thresholdModelTrainingDataSize, + modelSizeInBytes, + nodeId, + entity, + taskId, + taskType + ); + } + + public ADEntityTaskProfile(StreamInput input) throws IOException { + this.shingleSize = input.readOptionalInt(); + this.rcfTotalUpdates = input.readOptionalLong(); + this.thresholdModelTrained = input.readOptionalBoolean(); + this.thresholdModelTrainingDataSize = input.readOptionalInt(); + this.modelSizeInBytes = input.readOptionalLong(); + this.nodeId = 
input.readOptionalString(); + if (input.readBoolean()) { + this.entity = new Entity(input); + } else { + this.entity = null; + } + this.taskId = input.readOptionalString(); + this.adTaskType = input.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalInt(shingleSize); + out.writeOptionalLong(rcfTotalUpdates); + out.writeOptionalBoolean(thresholdModelTrained); + out.writeOptionalInt(thresholdModelTrainingDataSize); + out.writeOptionalLong(modelSizeInBytes); + out.writeOptionalString(nodeId); + if (entity != null) { + out.writeBoolean(true); + entity.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(taskId); + out.writeOptionalString(adTaskType); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + if (shingleSize != null) { + xContentBuilder.field(SHINGLE_SIZE_FIELD, shingleSize); + } + if (rcfTotalUpdates != null) { + xContentBuilder.field(RCF_TOTAL_UPDATES_FIELD, rcfTotalUpdates); + } + if (thresholdModelTrained != null) { + xContentBuilder.field(THRESHOLD_MODEL_TRAINED_FIELD, thresholdModelTrained); + } + if (thresholdModelTrainingDataSize != null) { + xContentBuilder.field(THRESHOLD_MODEL_TRAINING_DATA_SIZE_FIELD, thresholdModelTrainingDataSize); + } + if (modelSizeInBytes != null) { + xContentBuilder.field(MODEL_SIZE_IN_BYTES, modelSizeInBytes); + } + if (nodeId != null) { + xContentBuilder.field(NODE_ID_FIELD, nodeId); + } + if (entity != null) { + xContentBuilder.field(ENTITY_FIELD, entity); + } + if (taskId != null) { + xContentBuilder.field(TASK_ID_FIELD, taskId); + } + if (adTaskType != null) { + xContentBuilder.field(AD_TASK_TYPE_FIELD, adTaskType); + } + return xContentBuilder.endObject(); + } + + public Integer getShingleSize() { + return shingleSize; + } + + public void setShingleSize(Integer shingleSize) { + this.shingleSize = shingleSize; + } + + public Long getRcfTotalUpdates() { + return rcfTotalUpdates; + } + + public void setRcfTotalUpdates(Long rcfTotalUpdates) { + this.rcfTotalUpdates = rcfTotalUpdates; + } + + public Boolean getThresholdModelTrained() { + return thresholdModelTrained; + } + + public void setThresholdModelTrained(Boolean thresholdModelTrained) { + this.thresholdModelTrained = thresholdModelTrained; + } + + public Integer getThresholdModelTrainingDataSize() { + return thresholdModelTrainingDataSize; + } + + public void setThresholdModelTrainingDataSize(Integer thresholdModelTrainingDataSize) { + this.thresholdModelTrainingDataSize = thresholdModelTrainingDataSize; + } + + public Long getModelSizeInBytes() { + return modelSizeInBytes; + } + + public void setModelSizeInBytes(Long modelSizeInBytes) { + this.modelSizeInBytes = modelSizeInBytes; + } + + public String getNodeId() { + return nodeId; + } + + public void setNodeId(String nodeId) { + this.nodeId = nodeId; + } + + public Entity getEntity() { + return entity; + } + + public void setEntity(Entity entity) { + this.entity = entity; + } + + public String getTaskId() { + return taskId; + } + + public void setTaskId(String taskId) { + this.taskId = taskId; + } + + public String getAdTaskType() { + return adTaskType; + } + + public void setAdTaskType(String adTaskType) { + this.adTaskType = adTaskType; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + ADEntityTaskProfile that = 
(ADEntityTaskProfile) o; + return Objects.equals(shingleSize, that.shingleSize) + && Objects.equals(rcfTotalUpdates, that.rcfTotalUpdates) + && Objects.equals(thresholdModelTrained, that.thresholdModelTrained) + && Objects.equals(thresholdModelTrainingDataSize, that.thresholdModelTrainingDataSize) + && Objects.equals(modelSizeInBytes, that.modelSizeInBytes) + && Objects.equals(nodeId, that.nodeId) + && Objects.equals(taskId, that.taskId) + && Objects.equals(adTaskType, that.adTaskType) + && Objects.equals(entity, that.entity); + } + + @Override + public int hashCode() { + return Objects + .hash( + shingleSize, + rcfTotalUpdates, + thresholdModelTrained, + thresholdModelTrainingDataSize, + modelSizeInBytes, + nodeId, + entity, + taskId, + adTaskType + ); + } +} diff --git a/src/main/java/org/opensearch/ad/model/ADTask.java b/src/main/java/org/opensearch/ad/model/ADTask.java index 7eb9fe73c..0004f9640 100644 --- a/src/main/java/org/opensearch/ad/model/ADTask.java +++ b/src/main/java/org/opensearch/ad/model/ADTask.java @@ -12,15 +12,15 @@ package org.opensearch.ad.model; import static org.opensearch.ad.model.ADTaskState.NOT_ENDED_STATES; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.time.Instant; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; diff --git a/src/main/java/org/opensearch/ad/model/ADTask.java-e b/src/main/java/org/opensearch/ad/model/ADTask.java-e new file mode 100644 index 000000000..50d24965a --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/ADTask.java-e @@ -0,0 +1,797 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import static org.opensearch.ad.model.ADTaskState.NOT_ENDED_STATES; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.base.Objects; + +/** + * One anomaly detection task means one detector starts to run until stopped. 
+ */ +public class ADTask implements ToXContentObject, Writeable { + + public static final String TASK_ID_FIELD = "task_id"; + public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; + public static final String STARTED_BY_FIELD = "started_by"; + public static final String STOPPED_BY_FIELD = "stopped_by"; + public static final String ERROR_FIELD = "error"; + public static final String STATE_FIELD = "state"; + public static final String DETECTOR_ID_FIELD = "detector_id"; + public static final String TASK_PROGRESS_FIELD = "task_progress"; + public static final String INIT_PROGRESS_FIELD = "init_progress"; + public static final String CURRENT_PIECE_FIELD = "current_piece"; + public static final String EXECUTION_START_TIME_FIELD = "execution_start_time"; + public static final String EXECUTION_END_TIME_FIELD = "execution_end_time"; + public static final String IS_LATEST_FIELD = "is_latest"; + public static final String TASK_TYPE_FIELD = "task_type"; + public static final String CHECKPOINT_ID_FIELD = "checkpoint_id"; + public static final String COORDINATING_NODE_FIELD = "coordinating_node"; + public static final String WORKER_NODE_FIELD = "worker_node"; + public static final String DETECTOR_FIELD = "detector"; + public static final String DETECTION_DATE_RANGE_FIELD = "detection_date_range"; + public static final String ENTITY_FIELD = "entity"; + public static final String PARENT_TASK_ID_FIELD = "parent_task_id"; + public static final String ESTIMATED_MINUTES_LEFT_FIELD = "estimated_minutes_left"; + public static final String USER_FIELD = "user"; + public static final String HISTORICAL_TASK_PREFIX = "HISTORICAL"; + + private String taskId = null; + private Instant lastUpdateTime = null; + private String startedBy = null; + private String stoppedBy = null; + private String error = null; + private String state = null; + private String detectorId = null; + private Float taskProgress = null; + private Float initProgress = null; + private Instant currentPiece = null; + private Instant executionStartTime = null; + private Instant executionEndTime = null; + private Boolean isLatest = null; + private String taskType = null; + private String checkpointId = null; + private AnomalyDetector detector = null; + + private String coordinatingNode = null; + private String workerNode = null; + private DateRange detectionDateRange = null; + private Entity entity = null; + private String parentTaskId = null; + private Integer estimatedMinutesLeft = null; + private User user = null; + + private ADTask() {} + + public ADTask(StreamInput input) throws IOException { + this.taskId = input.readOptionalString(); + this.taskType = input.readOptionalString(); + this.detectorId = input.readOptionalString(); + if (input.readBoolean()) { + this.detector = new AnomalyDetector(input); + } else { + this.detector = null; + } + this.state = input.readOptionalString(); + this.taskProgress = input.readOptionalFloat(); + this.initProgress = input.readOptionalFloat(); + this.currentPiece = input.readOptionalInstant(); + this.executionStartTime = input.readOptionalInstant(); + this.executionEndTime = input.readOptionalInstant(); + this.isLatest = input.readOptionalBoolean(); + this.error = input.readOptionalString(); + this.checkpointId = input.readOptionalString(); + this.lastUpdateTime = input.readOptionalInstant(); + this.startedBy = input.readOptionalString(); + this.stoppedBy = input.readOptionalString(); + this.coordinatingNode = input.readOptionalString(); + this.workerNode = input.readOptionalString(); + if 
(input.readBoolean()) { + this.user = new User(input); + } else { + user = null; + } + // Below are new fields added since AD 1.1 + if (input.available() > 0) { + if (input.readBoolean()) { + this.detectionDateRange = new DateRange(input); + } else { + this.detectionDateRange = null; + } + if (input.readBoolean()) { + this.entity = new Entity(input); + } else { + this.entity = null; + } + this.parentTaskId = input.readOptionalString(); + this.estimatedMinutesLeft = input.readOptionalInt(); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(taskId); + out.writeOptionalString(taskType); + out.writeOptionalString(detectorId); + if (detector != null) { + out.writeBoolean(true); + detector.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(state); + out.writeOptionalFloat(taskProgress); + out.writeOptionalFloat(initProgress); + out.writeOptionalInstant(currentPiece); + out.writeOptionalInstant(executionStartTime); + out.writeOptionalInstant(executionEndTime); + out.writeOptionalBoolean(isLatest); + out.writeOptionalString(error); + out.writeOptionalString(checkpointId); + out.writeOptionalInstant(lastUpdateTime); + out.writeOptionalString(startedBy); + out.writeOptionalString(stoppedBy); + out.writeOptionalString(coordinatingNode); + out.writeOptionalString(workerNode); + if (user != null) { + out.writeBoolean(true); // user exists + user.writeTo(out); + } else { + out.writeBoolean(false); // user does not exist + } + // Only forward AD task to nodes with same version, so it's ok to write these new fields. + if (detectionDateRange != null) { + out.writeBoolean(true); + detectionDateRange.writeTo(out); + } else { + out.writeBoolean(false); + } + if (entity != null) { + out.writeBoolean(true); + entity.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(parentTaskId); + out.writeOptionalInt(estimatedMinutesLeft); + } + + public static Builder builder() { + return new Builder(); + } + + public boolean isHistoricalTask() { + return taskType.startsWith(HISTORICAL_TASK_PREFIX); + } + + public boolean isEntityTask() { + return ADTaskType.HISTORICAL_HC_ENTITY.name().equals(taskType); + } + + /** + * Get detector level task id. If a task has no parent task, the task is detector level task. + * @return detector level task id + */ + public String getDetectorLevelTaskId() { + return getParentTaskId() != null ? 
getParentTaskId() : getTaskId(); + } + + public boolean isDone() { + return !NOT_ENDED_STATES.contains(this.getState()); + } + + public static class Builder { + private String taskId = null; + private String taskType = null; + private String detectorId = null; + private AnomalyDetector detector = null; + private String state = null; + private Float taskProgress = null; + private Float initProgress = null; + private Instant currentPiece = null; + private Instant executionStartTime = null; + private Instant executionEndTime = null; + private Boolean isLatest = null; + private String error = null; + private String checkpointId = null; + private Instant lastUpdateTime = null; + private String startedBy = null; + private String stoppedBy = null; + private String coordinatingNode = null; + private String workerNode = null; + private DateRange detectionDateRange = null; + private Entity entity = null; + private String parentTaskId; + private Integer estimatedMinutesLeft; + private User user = null; + + public Builder() {} + + public Builder taskId(String taskId) { + this.taskId = taskId; + return this; + } + + public Builder lastUpdateTime(Instant lastUpdateTime) { + this.lastUpdateTime = lastUpdateTime; + return this; + } + + public Builder startedBy(String startedBy) { + this.startedBy = startedBy; + return this; + } + + public Builder stoppedBy(String stoppedBy) { + this.stoppedBy = stoppedBy; + return this; + } + + public Builder error(String error) { + this.error = error; + return this; + } + + public Builder state(String state) { + this.state = state; + return this; + } + + public Builder detectorId(String detectorId) { + this.detectorId = detectorId; + return this; + } + + public Builder taskProgress(Float taskProgress) { + this.taskProgress = taskProgress; + return this; + } + + public Builder initProgress(Float initProgress) { + this.initProgress = initProgress; + return this; + } + + public Builder currentPiece(Instant currentPiece) { + this.currentPiece = currentPiece; + return this; + } + + public Builder executionStartTime(Instant executionStartTime) { + this.executionStartTime = executionStartTime; + return this; + } + + public Builder executionEndTime(Instant executionEndTime) { + this.executionEndTime = executionEndTime; + return this; + } + + public Builder isLatest(Boolean isLatest) { + this.isLatest = isLatest; + return this; + } + + public Builder taskType(String taskType) { + this.taskType = taskType; + return this; + } + + public Builder checkpointId(String checkpointId) { + this.checkpointId = checkpointId; + return this; + } + + public Builder detector(AnomalyDetector detector) { + this.detector = detector; + return this; + } + + public Builder coordinatingNode(String coordinatingNode) { + this.coordinatingNode = coordinatingNode; + return this; + } + + public Builder workerNode(String workerNode) { + this.workerNode = workerNode; + return this; + } + + public Builder detectionDateRange(DateRange detectionDateRange) { + this.detectionDateRange = detectionDateRange; + return this; + } + + public Builder entity(Entity entity) { + this.entity = entity; + return this; + } + + public Builder parentTaskId(String parentTaskId) { + this.parentTaskId = parentTaskId; + return this; + } + + public Builder estimatedMinutesLeft(Integer estimatedMinutesLeft) { + this.estimatedMinutesLeft = estimatedMinutesLeft; + return this; + } + + public Builder user(User user) { + this.user = user; + return this; + } + + public ADTask build() { + ADTask adTask = new ADTask(); + adTask.taskId = 
this.taskId; + adTask.lastUpdateTime = this.lastUpdateTime; + adTask.error = this.error; + adTask.state = this.state; + adTask.detectorId = this.detectorId; + adTask.taskProgress = this.taskProgress; + adTask.initProgress = this.initProgress; + adTask.currentPiece = this.currentPiece; + adTask.executionStartTime = this.executionStartTime; + adTask.executionEndTime = this.executionEndTime; + adTask.isLatest = this.isLatest; + adTask.taskType = this.taskType; + adTask.checkpointId = this.checkpointId; + adTask.detector = this.detector; + adTask.startedBy = this.startedBy; + adTask.stoppedBy = this.stoppedBy; + adTask.coordinatingNode = this.coordinatingNode; + adTask.workerNode = this.workerNode; + adTask.detectionDateRange = this.detectionDateRange; + adTask.entity = this.entity; + adTask.parentTaskId = this.parentTaskId; + adTask.estimatedMinutesLeft = this.estimatedMinutesLeft; + adTask.user = this.user; + + return adTask; + } + + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + if (taskId != null) { + xContentBuilder.field(TASK_ID_FIELD, taskId); + } + if (lastUpdateTime != null) { + xContentBuilder.field(LAST_UPDATE_TIME_FIELD, lastUpdateTime.toEpochMilli()); + } + if (startedBy != null) { + xContentBuilder.field(STARTED_BY_FIELD, startedBy); + } + if (stoppedBy != null) { + xContentBuilder.field(STOPPED_BY_FIELD, stoppedBy); + } + if (error != null) { + xContentBuilder.field(ERROR_FIELD, error); + } + if (state != null) { + xContentBuilder.field(STATE_FIELD, state); + } + if (detectorId != null) { + xContentBuilder.field(DETECTOR_ID_FIELD, detectorId); + } + if (taskProgress != null) { + xContentBuilder.field(TASK_PROGRESS_FIELD, taskProgress); + } + if (initProgress != null) { + xContentBuilder.field(INIT_PROGRESS_FIELD, initProgress); + } + if (currentPiece != null) { + xContentBuilder.field(CURRENT_PIECE_FIELD, currentPiece.toEpochMilli()); + } + if (executionStartTime != null) { + xContentBuilder.field(EXECUTION_START_TIME_FIELD, executionStartTime.toEpochMilli()); + } + if (executionEndTime != null) { + xContentBuilder.field(EXECUTION_END_TIME_FIELD, executionEndTime.toEpochMilli()); + } + if (isLatest != null) { + xContentBuilder.field(IS_LATEST_FIELD, isLatest); + } + if (taskType != null) { + xContentBuilder.field(TASK_TYPE_FIELD, taskType); + } + if (checkpointId != null) { + xContentBuilder.field(CHECKPOINT_ID_FIELD, checkpointId); + } + if (coordinatingNode != null) { + xContentBuilder.field(COORDINATING_NODE_FIELD, coordinatingNode); + } + if (workerNode != null) { + xContentBuilder.field(WORKER_NODE_FIELD, workerNode); + } + if (detector != null) { + xContentBuilder.field(DETECTOR_FIELD, detector); + } + if (detectionDateRange != null) { + xContentBuilder.field(DETECTION_DATE_RANGE_FIELD, detectionDateRange); + } + if (entity != null) { + xContentBuilder.field(ENTITY_FIELD, entity); + } + if (parentTaskId != null) { + xContentBuilder.field(PARENT_TASK_ID_FIELD, parentTaskId); + } + if (estimatedMinutesLeft != null) { + xContentBuilder.field(ESTIMATED_MINUTES_LEFT_FIELD, estimatedMinutesLeft); + } + if (user != null) { + xContentBuilder.field(USER_FIELD, user); + } + return xContentBuilder.endObject(); + } + + public static ADTask parse(XContentParser parser) throws IOException { + return parse(parser, null); + } + + public static ADTask parse(XContentParser parser, String taskId) throws IOException { + Instant lastUpdateTime = null; + String 
startedBy = null; + String stoppedBy = null; + String error = null; + String state = null; + String detectorId = null; + Float taskProgress = null; + Float initProgress = null; + Instant currentPiece = null; + Instant executionStartTime = null; + Instant executionEndTime = null; + Boolean isLatest = null; + String taskType = null; + String checkpointId = null; + AnomalyDetector detector = null; + String parsedTaskId = taskId; + String coordinatingNode = null; + String workerNode = null; + DateRange detectionDateRange = null; + Entity entity = null; + String parentTaskId = null; + Integer estimatedMinutesLeft = null; + User user = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case LAST_UPDATE_TIME_FIELD: + lastUpdateTime = ParseUtils.toInstant(parser); + break; + case STARTED_BY_FIELD: + startedBy = parser.text(); + break; + case STOPPED_BY_FIELD: + stoppedBy = parser.text(); + break; + case ERROR_FIELD: + error = parser.text(); + break; + case STATE_FIELD: + state = parser.text(); + break; + case DETECTOR_ID_FIELD: + detectorId = parser.text(); + break; + case TASK_PROGRESS_FIELD: + taskProgress = parser.floatValue(); + break; + case INIT_PROGRESS_FIELD: + initProgress = parser.floatValue(); + break; + case CURRENT_PIECE_FIELD: + currentPiece = ParseUtils.toInstant(parser); + break; + case EXECUTION_START_TIME_FIELD: + executionStartTime = ParseUtils.toInstant(parser); + break; + case EXECUTION_END_TIME_FIELD: + executionEndTime = ParseUtils.toInstant(parser); + break; + case IS_LATEST_FIELD: + isLatest = parser.booleanValue(); + break; + case TASK_TYPE_FIELD: + taskType = parser.text(); + break; + case CHECKPOINT_ID_FIELD: + checkpointId = parser.text(); + break; + case DETECTOR_FIELD: + detector = AnomalyDetector.parse(parser); + break; + case TASK_ID_FIELD: + parsedTaskId = parser.text(); + break; + case COORDINATING_NODE_FIELD: + coordinatingNode = parser.text(); + break; + case WORKER_NODE_FIELD: + workerNode = parser.text(); + break; + case DETECTION_DATE_RANGE_FIELD: + detectionDateRange = DateRange.parse(parser); + break; + case ENTITY_FIELD: + entity = Entity.parse(parser); + break; + case PARENT_TASK_ID_FIELD: + parentTaskId = parser.text(); + break; + case ESTIMATED_MINUTES_LEFT_FIELD: + estimatedMinutesLeft = parser.intValue(); + break; + case USER_FIELD: + user = User.parse(parser); + break; + default: + parser.skipChildren(); + break; + } + } + AnomalyDetector anomalyDetector = detector == null + ? 
null + : new AnomalyDetector( + detectorId, + detector.getVersion(), + detector.getName(), + detector.getDescription(), + detector.getTimeField(), + detector.getIndices(), + detector.getFeatureAttributes(), + detector.getFilterQuery(), + detector.getInterval(), + detector.getWindowDelay(), + detector.getShingleSize(), + detector.getUiMetadata(), + detector.getSchemaVersion(), + detector.getLastUpdateTime(), + detector.getCategoryFields(), + detector.getUser(), + detector.getCustomResultIndex(), + detector.getImputationOption() + ); + return new Builder() + .taskId(parsedTaskId) + .lastUpdateTime(lastUpdateTime) + .startedBy(startedBy) + .stoppedBy(stoppedBy) + .error(error) + .state(state) + .detectorId(detectorId) + .taskProgress(taskProgress) + .initProgress(initProgress) + .currentPiece(currentPiece) + .executionStartTime(executionStartTime) + .executionEndTime(executionEndTime) + .isLatest(isLatest) + .taskType(taskType) + .checkpointId(checkpointId) + .coordinatingNode(coordinatingNode) + .workerNode(workerNode) + .detector(anomalyDetector) + .detectionDateRange(detectionDateRange) + .entity(entity) + .parentTaskId(parentTaskId) + .estimatedMinutesLeft(estimatedMinutesLeft) + .user(user) + .build(); + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + ADTask that = (ADTask) o; + return Objects.equal(getTaskId(), that.getTaskId()) + && Objects.equal(getLastUpdateTime(), that.getLastUpdateTime()) + && Objects.equal(getStartedBy(), that.getStartedBy()) + && Objects.equal(getStoppedBy(), that.getStoppedBy()) + && Objects.equal(getError(), that.getError()) + && Objects.equal(getState(), that.getState()) + && Objects.equal(getId(), that.getId()) + && Objects.equal(getTaskProgress(), that.getTaskProgress()) + && Objects.equal(getInitProgress(), that.getInitProgress()) + && Objects.equal(getCurrentPiece(), that.getCurrentPiece()) + && Objects.equal(getExecutionStartTime(), that.getExecutionStartTime()) + && Objects.equal(getExecutionEndTime(), that.getExecutionEndTime()) + && Objects.equal(getLatest(), that.getLatest()) + && Objects.equal(getTaskType(), that.getTaskType()) + && Objects.equal(getCheckpointId(), that.getCheckpointId()) + && Objects.equal(getCoordinatingNode(), that.getCoordinatingNode()) + && Objects.equal(getWorkerNode(), that.getWorkerNode()) + && Objects.equal(getDetector(), that.getDetector()) + && Objects.equal(getDetectionDateRange(), that.getDetectionDateRange()) + && Objects.equal(getEntity(), that.getEntity()) + && Objects.equal(getParentTaskId(), that.getParentTaskId()) + && Objects.equal(getEstimatedMinutesLeft(), that.getEstimatedMinutesLeft()) + && Objects.equal(getUser(), that.getUser()); + } + + @Generated + @Override + public int hashCode() { + return Objects + .hashCode( + taskId, + lastUpdateTime, + startedBy, + stoppedBy, + error, + state, + detectorId, + taskProgress, + initProgress, + currentPiece, + executionStartTime, + executionEndTime, + isLatest, + taskType, + checkpointId, + coordinatingNode, + workerNode, + detector, + detectionDateRange, + entity, + parentTaskId, + estimatedMinutesLeft, + user + ); + } + + public String getTaskId() { + return taskId; + } + + public void setTaskId(String taskId) { + this.taskId = taskId; + } + + public Instant getLastUpdateTime() { + return lastUpdateTime; + } + + public String getStartedBy() { + return startedBy; + } + + public String getStoppedBy() { + return stoppedBy; + } + + public String getError() { + 
return error; + } + + public void setError(String error) { + this.error = error; + } + + public String getState() { + return state; + } + + public void setState(String state) { + this.state = state; + } + + public String getId() { + return detectorId; + } + + public Float getTaskProgress() { + return taskProgress; + } + + public Float getInitProgress() { + return initProgress; + } + + public Instant getCurrentPiece() { + return currentPiece; + } + + public Instant getExecutionStartTime() { + return executionStartTime; + } + + public Instant getExecutionEndTime() { + return executionEndTime; + } + + public Boolean getLatest() { + return isLatest; + } + + public String getTaskType() { + return taskType; + } + + public String getCheckpointId() { + return checkpointId; + } + + public AnomalyDetector getDetector() { + return detector; + } + + public String getCoordinatingNode() { + return coordinatingNode; + } + + public String getWorkerNode() { + return workerNode; + } + + public DateRange getDetectionDateRange() { + return detectionDateRange; + } + + public Entity getEntity() { + return entity; + } + + public String getEntityModelId() { + return entity == null ? null : entity.getModelId(getId()).orElse(null); + } + + public String getParentTaskId() { + return parentTaskId; + } + + public Integer getEstimatedMinutesLeft() { + return estimatedMinutesLeft; + } + + public User getUser() { + return user; + } + + public void setDetectionDateRange(DateRange detectionDateRange) { + this.detectionDateRange = detectionDateRange; + } + + public void setLatest(Boolean latest) { + isLatest = latest; + } + + public void setLastUpdateTime(Instant lastUpdateTime) { + this.lastUpdateTime = lastUpdateTime; + } +} diff --git a/src/main/java/org/opensearch/ad/model/ADTaskAction.java-e b/src/main/java/org/opensearch/ad/model/ADTaskAction.java-e new file mode 100644 index 000000000..b58b24cd5 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/ADTaskAction.java-e @@ -0,0 +1,87 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import org.opensearch.action.ActionListener; +import org.opensearch.transport.TransportService; + +/** + * AD task action enum. Have 2 classes of task actions: + *
<ul> + *     <li>AD task actions which execute on coordinating node.</li> + *     <li>Task slot actions which execute on lead node.</li> + * </ul>
+ */ +public enum ADTaskAction { + // ====================================== + // Actions execute on coordinating node + // ====================================== + /** + * Start historical analysis for detector. + *
<p>Execute on coordinating node</p>
+ */ + START, + /** + * Historical analysis finished, so we need to remove detector cache. Used for these cases + *
<ul> + *     <li>Single entity detector finished/failed/cancelled. Check ADBatchTaskRunner#internalBatchTaskListener</li> + *     <li>Reset task state as stopped. Check ADTaskManager#resetTaskStateAsStopped</li> + *     <li>When stop realtime job, we need to stop task and clean up cache. Check ADTaskManager#stopLatestRealtimeTask</li> + *     <li>When start realtime job, will clean stale cache on old coordinating node. + *     Check ADTaskManager#initRealtimeTaskCacheAndCleanupStaleCache</li> + * </ul>
+ */ + CLEAN_CACHE, + /** + * Cancel historical analysis. Currently only used for HC detector. Single entity detector just need + * to cancel itself. HC detector need to cancel detector level task on coordinating node. + *
<p>Execute on coordinating node</p>
+ */ + CANCEL, + /** + * Run next entity for HC detector historical analysis. If no entity, will set detector task as done. + *
<p>Execute on coordinating node</p>
+ */ + NEXT_ENTITY, + /** + * If any retryable exception happens for HC entity task like limit exceed exception, will push back + * entity to pending entities queue and run next entity. + *
<p>Execute on coordinating node</p>
+ */ + PUSH_BACK_ENTITY, + /** + * Clean stale entities in running entity queue, for example the work node crashed and fail to remove + * entity from running entity queue on coordinating node. + *
<p>Execute on coordinating node</p>
+ */ + CLEAN_STALE_RUNNING_ENTITIES, + /** + * Scale entity task slots for HC historical analysis. + * Check {@link org.opensearch.ad.task.ADTaskManager#runNextEntityForHCADHistorical(ADTask, TransportService, ActionListener)}. + *
<p>Execute on coordinating node</p>
+ */ + SCALE_ENTITY_TASK_SLOTS, + + // ====================================== + // Actions execute on lead node + // ====================================== + /** + * Apply for task slots when historical analysis starts. + *
<p>Execute on lead node</p>
+ */ + APPLY_FOR_TASK_SLOTS, + /** + * Check current available task slots in cluster. HC historical analysis need this to scale task slots. + *
<p>Execute on lead node</p>
+ */ + CHECK_AVAILABLE_TASK_SLOTS, +} diff --git a/src/main/java/org/opensearch/ad/model/ADTaskProfile.java b/src/main/java/org/opensearch/ad/model/ADTaskProfile.java index d95c6c579..cd6eaeaa0 100644 --- a/src/main/java/org/opensearch/ad/model/ADTaskProfile.java +++ b/src/main/java/org/opensearch/ad/model/ADTaskProfile.java @@ -11,7 +11,7 @@ package org.opensearch.ad.model; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.util.ArrayList; @@ -19,9 +19,9 @@ import java.util.Objects; import org.opensearch.Version; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; diff --git a/src/main/java/org/opensearch/ad/model/ADTaskProfile.java-e b/src/main/java/org/opensearch/ad/model/ADTaskProfile.java-e new file mode 100644 index 000000000..cd6eaeaa0 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/ADTaskProfile.java-e @@ -0,0 +1,603 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import org.opensearch.Version; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.annotation.Generated; + +/** + * One anomaly detection task means one detector starts to run until stopped. 
+ */ +public class ADTaskProfile implements ToXContentObject, Writeable { + + public static final String AD_TASK_FIELD = "ad_task"; + public static final String SHINGLE_SIZE_FIELD = "shingle_size"; + public static final String RCF_TOTAL_UPDATES_FIELD = "rcf_total_updates"; + public static final String THRESHOLD_MODEL_TRAINED_FIELD = "threshold_model_trained"; + public static final String THRESHOLD_MODEL_TRAINING_DATA_SIZE_FIELD = "threshold_model_training_data_size"; + public static final String MODEL_SIZE_IN_BYTES = "model_size_in_bytes"; + public static final String NODE_ID_FIELD = "node_id"; + public static final String TASK_ID_FIELD = "task_id"; + public static final String AD_TASK_TYPE_FIELD = "task_type"; + public static final String DETECTOR_TASK_SLOTS_FIELD = "detector_task_slots"; + public static final String TOTAL_ENTITIES_INITED_FIELD = "total_entities_inited"; + public static final String TOTAL_ENTITIES_COUNT_FIELD = "total_entities_count"; + public static final String PENDING_ENTITIES_COUNT_FIELD = "pending_entities_count"; + public static final String RUNNING_ENTITIES_COUNT_FIELD = "running_entities_count"; + public static final String RUNNING_ENTITIES_FIELD = "running_entities"; + public static final String ENTITY_TASK_PROFILE_FIELD = "entity_task_profiles"; + public static final String LATEST_HC_TASK_RUN_TIME_FIELD = "latest_hc_task_run_time"; + + private ADTask adTask; + private Integer shingleSize; + private Long rcfTotalUpdates; + private Boolean thresholdModelTrained; + private Integer thresholdModelTrainingDataSize; + private Long modelSizeInBytes; + private String nodeId; + private String taskId; + private String adTaskType; + private Integer detectorTaskSlots; + private Boolean totalEntitiesInited; + private Integer totalEntitiesCount; + private Integer pendingEntitiesCount; + private Integer runningEntitiesCount; + private List runningEntities; + private Long latestHCTaskRunTime; + + private List entityTaskProfiles; + + public ADTaskProfile() { + + } + + public ADTaskProfile(ADTask adTask) { + this.adTask = adTask; + } + + public ADTaskProfile( + String taskId, + int shingleSize, + long rcfTotalUpdates, + boolean thresholdModelTrained, + int thresholdModelTrainingDataSize, + long modelSizeInBytes, + String nodeId + ) { + this.taskId = taskId; + this.shingleSize = shingleSize; + this.rcfTotalUpdates = rcfTotalUpdates; + this.thresholdModelTrained = thresholdModelTrained; + this.thresholdModelTrainingDataSize = thresholdModelTrainingDataSize; + this.modelSizeInBytes = modelSizeInBytes; + this.nodeId = nodeId; + } + + public ADTaskProfile( + ADTask adTask, + Integer shingleSize, + Long rcfTotalUpdates, + Boolean thresholdModelTrained, + Integer thresholdModelTrainingDataSize, + Long modelSizeInBytes, + String nodeId, + String taskId, + String adTaskType, + Integer detectorTaskSlots, + Boolean totalEntitiesInited, + Integer totalEntitiesCount, + Integer pendingEntitiesCount, + Integer runningEntitiesCount, + List runningEntities, + Long latestHCTaskRunTime + ) { + this.adTask = adTask; + this.shingleSize = shingleSize; + this.rcfTotalUpdates = rcfTotalUpdates; + this.thresholdModelTrained = thresholdModelTrained; + this.thresholdModelTrainingDataSize = thresholdModelTrainingDataSize; + this.modelSizeInBytes = modelSizeInBytes; + this.nodeId = nodeId; + this.taskId = taskId; + this.adTaskType = adTaskType; + this.detectorTaskSlots = detectorTaskSlots; + this.totalEntitiesInited = totalEntitiesInited; + this.totalEntitiesCount = totalEntitiesCount; + 
this.pendingEntitiesCount = pendingEntitiesCount; + this.runningEntitiesCount = runningEntitiesCount; + this.runningEntities = runningEntities; + this.latestHCTaskRunTime = latestHCTaskRunTime; + } + + public ADTaskProfile(StreamInput input) throws IOException { + if (input.readBoolean()) { + this.adTask = new ADTask(input); + } else { + this.adTask = null; + } + this.shingleSize = input.readOptionalInt(); + this.rcfTotalUpdates = input.readOptionalLong(); + this.thresholdModelTrained = input.readOptionalBoolean(); + this.thresholdModelTrainingDataSize = input.readOptionalInt(); + this.modelSizeInBytes = input.readOptionalLong(); + this.nodeId = input.readOptionalString(); + if (input.available() > 0) { + this.taskId = input.readOptionalString(); + this.adTaskType = input.readOptionalString(); + this.detectorTaskSlots = input.readOptionalInt(); + this.totalEntitiesInited = input.readOptionalBoolean(); + this.totalEntitiesCount = input.readOptionalInt(); + this.pendingEntitiesCount = input.readOptionalInt(); + this.runningEntitiesCount = input.readOptionalInt(); + if (input.readBoolean()) { + this.runningEntities = input.readStringList(); + } + if (input.readBoolean()) { + this.entityTaskProfiles = input.readList(ADEntityTaskProfile::new); + } + this.latestHCTaskRunTime = input.readOptionalLong(); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + writeTo(out, Version.CURRENT); + } + + public void writeTo(StreamOutput out, Version adVersion) throws IOException { + if (adTask != null) { + out.writeBoolean(true); + adTask.writeTo(out); + } else { + out.writeBoolean(false); + } + + out.writeOptionalInt(shingleSize); + out.writeOptionalLong(rcfTotalUpdates); + out.writeOptionalBoolean(thresholdModelTrained); + out.writeOptionalInt(thresholdModelTrainingDataSize); + out.writeOptionalLong(modelSizeInBytes); + out.writeOptionalString(nodeId); + if (adVersion != null) { + out.writeOptionalString(taskId); + out.writeOptionalString(adTaskType); + out.writeOptionalInt(detectorTaskSlots); + out.writeOptionalBoolean(totalEntitiesInited); + out.writeOptionalInt(totalEntitiesCount); + out.writeOptionalInt(pendingEntitiesCount); + out.writeOptionalInt(runningEntitiesCount); + if (runningEntities != null && runningEntities.size() > 0) { + out.writeBoolean(true); + out.writeStringCollection(runningEntities); + } else { + out.writeBoolean(false); + } + if (entityTaskProfiles != null && entityTaskProfiles.size() > 0) { + out.writeBoolean(true); + out.writeList(entityTaskProfiles); + } else { + out.writeBoolean(false); + } + out.writeOptionalLong(latestHCTaskRunTime); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + if (adTask != null) { + xContentBuilder.field(AD_TASK_FIELD, adTask); + } + if (shingleSize != null) { + xContentBuilder.field(SHINGLE_SIZE_FIELD, shingleSize); + } + if (rcfTotalUpdates != null) { + xContentBuilder.field(RCF_TOTAL_UPDATES_FIELD, rcfTotalUpdates); + } + if (thresholdModelTrained != null) { + xContentBuilder.field(THRESHOLD_MODEL_TRAINED_FIELD, thresholdModelTrained); + } + if (thresholdModelTrainingDataSize != null) { + xContentBuilder.field(THRESHOLD_MODEL_TRAINING_DATA_SIZE_FIELD, thresholdModelTrainingDataSize); + } + if (modelSizeInBytes != null) { + xContentBuilder.field(MODEL_SIZE_IN_BYTES, modelSizeInBytes); + } + if (nodeId != null) { + xContentBuilder.field(NODE_ID_FIELD, nodeId); + } + if (taskId != null) { + 
xContentBuilder.field(TASK_ID_FIELD, taskId); + } + if (adTaskType != null) { + xContentBuilder.field(AD_TASK_TYPE_FIELD, adTaskType); + } + if (detectorTaskSlots != null) { + xContentBuilder.field(DETECTOR_TASK_SLOTS_FIELD, detectorTaskSlots); + } + if (totalEntitiesInited != null) { + xContentBuilder.field(TOTAL_ENTITIES_INITED_FIELD, totalEntitiesInited); + } + if (totalEntitiesCount != null) { + xContentBuilder.field(TOTAL_ENTITIES_COUNT_FIELD, totalEntitiesCount); + } + if (pendingEntitiesCount != null) { + xContentBuilder.field(PENDING_ENTITIES_COUNT_FIELD, pendingEntitiesCount); + } + if (runningEntitiesCount != null) { + xContentBuilder.field(RUNNING_ENTITIES_COUNT_FIELD, runningEntitiesCount); + } + if (runningEntities != null) { + xContentBuilder.field(RUNNING_ENTITIES_FIELD, runningEntities); + } + if (entityTaskProfiles != null && entityTaskProfiles.size() > 0) { + xContentBuilder.field(ENTITY_TASK_PROFILE_FIELD, entityTaskProfiles.toArray()); + } + if (latestHCTaskRunTime != null) { + xContentBuilder.field(LATEST_HC_TASK_RUN_TIME_FIELD, latestHCTaskRunTime); + } + return xContentBuilder.endObject(); + } + + public static ADTaskProfile parse(XContentParser parser) throws IOException { + ADTask adTask = null; + Integer shingleSize = null; + Long rcfTotalUpdates = null; + Boolean thresholdModelTrained = null; + Integer thresholdModelTrainingDataSize = null; + Long modelSizeInBytes = null; + String nodeId = null; + String taskId = null; + String taskType = null; + Integer detectorTaskSlots = null; + Boolean totalEntitiesInited = null; + Integer totalEntitiesCount = null; + Integer pendingEntitiesCount = null; + Integer runningEntitiesCount = null; + List runningEntities = null; + List entityTaskProfiles = null; + Long latestHCTaskRunTime = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case AD_TASK_FIELD: + adTask = ADTask.parse(parser); + break; + case SHINGLE_SIZE_FIELD: + shingleSize = parser.intValue(); + break; + case RCF_TOTAL_UPDATES_FIELD: + rcfTotalUpdates = parser.longValue(); + break; + case THRESHOLD_MODEL_TRAINED_FIELD: + thresholdModelTrained = parser.booleanValue(); + break; + case THRESHOLD_MODEL_TRAINING_DATA_SIZE_FIELD: + thresholdModelTrainingDataSize = parser.intValue(); + break; + case MODEL_SIZE_IN_BYTES: + modelSizeInBytes = parser.longValue(); + break; + case NODE_ID_FIELD: + nodeId = parser.text(); + break; + case TASK_ID_FIELD: + taskId = parser.text(); + break; + case AD_TASK_TYPE_FIELD: + taskType = parser.text(); + break; + case DETECTOR_TASK_SLOTS_FIELD: + detectorTaskSlots = parser.intValue(); + break; + case TOTAL_ENTITIES_INITED_FIELD: + totalEntitiesInited = parser.booleanValue(); + break; + case TOTAL_ENTITIES_COUNT_FIELD: + totalEntitiesCount = parser.intValue(); + break; + case PENDING_ENTITIES_COUNT_FIELD: + pendingEntitiesCount = parser.intValue(); + break; + case RUNNING_ENTITIES_COUNT_FIELD: + runningEntitiesCount = parser.intValue(); + break; + case RUNNING_ENTITIES_FIELD: + runningEntities = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + runningEntities.add(parser.text()); + } + break; + case ENTITY_TASK_PROFILE_FIELD: + entityTaskProfiles = new ArrayList<>(); + 
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + entityTaskProfiles.add(ADEntityTaskProfile.parse(parser)); + } + break; + case LATEST_HC_TASK_RUN_TIME_FIELD: + latestHCTaskRunTime = parser.longValue(); + break; + default: + parser.skipChildren(); + break; + } + } + return new ADTaskProfile( + adTask, + shingleSize, + rcfTotalUpdates, + thresholdModelTrained, + thresholdModelTrainingDataSize, + modelSizeInBytes, + nodeId, + taskId, + taskType, + detectorTaskSlots, + totalEntitiesInited, + totalEntitiesCount, + pendingEntitiesCount, + runningEntitiesCount, + runningEntities, + latestHCTaskRunTime + ); + } + + public ADTask getAdTask() { + return adTask; + } + + public void setAdTask(ADTask adTask) { + this.adTask = adTask; + } + + public Integer getShingleSize() { + return shingleSize; + } + + public void setShingleSize(Integer shingleSize) { + this.shingleSize = shingleSize; + } + + public Long getRcfTotalUpdates() { + return rcfTotalUpdates; + } + + public void setRcfTotalUpdates(Long rcfTotalUpdates) { + this.rcfTotalUpdates = rcfTotalUpdates; + } + + public Boolean getThresholdModelTrained() { + return thresholdModelTrained; + } + + public void setThresholdModelTrained(Boolean thresholdModelTrained) { + this.thresholdModelTrained = thresholdModelTrained; + } + + public Integer getThresholdModelTrainingDataSize() { + return thresholdModelTrainingDataSize; + } + + public void setThresholdModelTrainingDataSize(Integer thresholdModelTrainingDataSize) { + this.thresholdModelTrainingDataSize = thresholdModelTrainingDataSize; + } + + public Long getModelSizeInBytes() { + return modelSizeInBytes; + } + + public void setModelSizeInBytes(Long modelSizeInBytes) { + this.modelSizeInBytes = modelSizeInBytes; + } + + public String getNodeId() { + return nodeId; + } + + public void setNodeId(String nodeId) { + this.nodeId = nodeId; + } + + public String getTaskId() { + return taskId; + } + + public void setTaskId(String taskId) { + this.taskId = taskId; + } + + public String getAdTaskType() { + return adTaskType; + } + + public void setAdTaskType(String adTaskType) { + this.adTaskType = adTaskType; + } + + public boolean getTotalEntitiesInited() { + return totalEntitiesInited != null && totalEntitiesInited.booleanValue(); + } + + public void setTotalEntitiesInited(Boolean totalEntitiesInited) { + this.totalEntitiesInited = totalEntitiesInited; + } + + public Long getLatestHCTaskRunTime() { + return latestHCTaskRunTime; + } + + public void setLatestHCTaskRunTime(Long latestHCTaskRunTime) { + this.latestHCTaskRunTime = latestHCTaskRunTime; + } + + public Integer getTotalEntitiesCount() { + return totalEntitiesCount; + } + + public void setTotalEntitiesCount(Integer totalEntitiesCount) { + this.totalEntitiesCount = totalEntitiesCount; + } + + public Integer getDetectorTaskSlots() { + return detectorTaskSlots; + } + + public void setDetectorTaskSlots(Integer detectorTaskSlots) { + this.detectorTaskSlots = detectorTaskSlots; + } + + public Integer getPendingEntitiesCount() { + return pendingEntitiesCount; + } + + public void setPendingEntitiesCount(Integer pendingEntitiesCount) { + this.pendingEntitiesCount = pendingEntitiesCount; + } + + public Integer getRunningEntitiesCount() { + return runningEntitiesCount; + } + + public void setRunningEntitiesCount(Integer runningEntitiesCount) { + this.runningEntitiesCount = runningEntitiesCount; + } + + public List getRunningEntities() { + return runningEntities; 
+ } + + public void setRunningEntities(List runningEntities) { + this.runningEntities = runningEntities; + } + + public List getEntityTaskProfiles() { + return entityTaskProfiles; + } + + public void setEntityTaskProfiles(List entityTaskProfiles) { + this.entityTaskProfiles = entityTaskProfiles; + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + ADTaskProfile that = (ADTaskProfile) o; + return Objects.equals(adTask, that.adTask) + && Objects.equals(shingleSize, that.shingleSize) + && Objects.equals(rcfTotalUpdates, that.rcfTotalUpdates) + && Objects.equals(thresholdModelTrained, that.thresholdModelTrained) + && Objects.equals(thresholdModelTrainingDataSize, that.thresholdModelTrainingDataSize) + && Objects.equals(modelSizeInBytes, that.modelSizeInBytes) + && Objects.equals(nodeId, that.nodeId) + && Objects.equals(taskId, that.taskId) + && Objects.equals(adTaskType, that.adTaskType) + && Objects.equals(detectorTaskSlots, that.detectorTaskSlots) + && Objects.equals(totalEntitiesInited, that.totalEntitiesInited) + && Objects.equals(totalEntitiesCount, that.totalEntitiesCount) + && Objects.equals(pendingEntitiesCount, that.pendingEntitiesCount) + && Objects.equals(runningEntitiesCount, that.runningEntitiesCount) + && Objects.equals(runningEntities, that.runningEntities) + && Objects.equals(latestHCTaskRunTime, that.latestHCTaskRunTime) + && Objects.equals(entityTaskProfiles, that.entityTaskProfiles); + } + + @Generated + @Override + public int hashCode() { + return Objects + .hash( + adTask, + shingleSize, + rcfTotalUpdates, + thresholdModelTrained, + thresholdModelTrainingDataSize, + modelSizeInBytes, + nodeId, + taskId, + adTaskType, + detectorTaskSlots, + totalEntitiesInited, + totalEntitiesCount, + pendingEntitiesCount, + runningEntitiesCount, + runningEntities, + entityTaskProfiles, + latestHCTaskRunTime + ); + } + + @Override + public String toString() { + return "ADTaskProfile{" + + "adTask=" + + adTask + + ", shingleSize=" + + shingleSize + + ", rcfTotalUpdates=" + + rcfTotalUpdates + + ", thresholdModelTrained=" + + thresholdModelTrained + + ", thresholdModelTrainingDataSize=" + + thresholdModelTrainingDataSize + + ", modelSizeInBytes=" + + modelSizeInBytes + + ", nodeId='" + + nodeId + + '\'' + + ", taskId='" + + taskId + + '\'' + + ", adTaskType='" + + adTaskType + + '\'' + + ", detectorTaskSlots=" + + detectorTaskSlots + + ", totalEntitiesInited=" + + totalEntitiesInited + + ", totalEntitiesCount=" + + totalEntitiesCount + + ", pendingEntitiesCount=" + + pendingEntitiesCount + + ", runningEntitiesCount=" + + runningEntitiesCount + + ", runningEntities=" + + runningEntities + + ", latestHCTaskRunTime=" + + latestHCTaskRunTime + + ", entityTaskProfiles=" + + entityTaskProfiles + + '}'; + } +} diff --git a/src/main/java/org/opensearch/ad/model/ADTaskState.java-e b/src/main/java/org/opensearch/ad/model/ADTaskState.java-e new file mode 100644 index 000000000..68462f816 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/ADTaskState.java-e @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.model; + +import java.util.List; + +import com.google.common.collect.ImmutableList; + +/** + * AD task states. + *
<ul> + *
<li>CREATED: + * When user start a historical detector, we will create one task to track the detector + * execution and set its state as CREATED + * + *
<li>INIT: + * After task created, coordinate node will gather all eligible node’s state and dispatch + * task to the worker node with lowest load. When the worker node receives the request, + * it will set the task state as INIT immediately, then start to run cold start to train + * RCF model. We will track the initialization progress in task. + * Init_Progress=ModelUpdates/MinSampleSize + * + *
<li>RUNNING: + * If RCF model gets enough data points and passed training, it will start to detect data + * normally and output positive anomaly scores. Once the RCF model starts to output positive + * anomaly score, we will set the task state as RUNNING and init progress as 100%. We will + * track task running progress in task: Task_Progress=DetectedPieces/AllPieces + * + *
<li>FINISHED: + * When all historical data detected, we set the task state as FINISHED and task progress + * as 100%. + * + *
<li>STOPPED: + * User can cancel a running task by stopping detector, for example, user want to tune + * feature and reran and don’t want current task run any more. When a historical detector + * stopped, we will mark the task flag cancelled as true, when run next piece, we will + * check this flag and stop the task. Then task stopped, will set its state as STOPPED + * + *
<li>FAILED: + * If any exception happen, we will set task state as FAILED + * </ul>
+ */ +public enum ADTaskState { + CREATED, + INIT, + RUNNING, + FAILED, + STOPPED, + FINISHED; + + public static List NOT_ENDED_STATES = ImmutableList + .of(ADTaskState.CREATED.name(), ADTaskState.INIT.name(), ADTaskState.RUNNING.name()); +} diff --git a/src/main/java/org/opensearch/ad/model/ADTaskType.java-e b/src/main/java/org/opensearch/ad/model/ADTaskType.java-e new file mode 100644 index 000000000..b4e06aefc --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/ADTaskType.java-e @@ -0,0 +1,48 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.util.List; +import java.util.stream.Collectors; + +import com.google.common.collect.ImmutableList; + +public enum ADTaskType { + @Deprecated + HISTORICAL, + REALTIME_SINGLE_ENTITY, + REALTIME_HC_DETECTOR, + HISTORICAL_SINGLE_ENTITY, + // detector level task to track overall state, init progress, error etc. for HC detector + HISTORICAL_HC_DETECTOR, + // entity level task to track just one specific entity's state, init progress, error etc. + HISTORICAL_HC_ENTITY; + + public static List HISTORICAL_DETECTOR_TASK_TYPES = ImmutableList + .of(ADTaskType.HISTORICAL_HC_DETECTOR, ADTaskType.HISTORICAL_SINGLE_ENTITY, ADTaskType.HISTORICAL); + public static List ALL_HISTORICAL_TASK_TYPES = ImmutableList + .of(ADTaskType.HISTORICAL_HC_DETECTOR, ADTaskType.HISTORICAL_SINGLE_ENTITY, ADTaskType.HISTORICAL_HC_ENTITY, ADTaskType.HISTORICAL); + public static List REALTIME_TASK_TYPES = ImmutableList + .of(ADTaskType.REALTIME_SINGLE_ENTITY, ADTaskType.REALTIME_HC_DETECTOR); + public static List ALL_DETECTOR_TASK_TYPES = ImmutableList + .of( + ADTaskType.REALTIME_SINGLE_ENTITY, + ADTaskType.REALTIME_HC_DETECTOR, + ADTaskType.HISTORICAL_SINGLE_ENTITY, + ADTaskType.HISTORICAL_HC_DETECTOR, + ADTaskType.HISTORICAL + ); + + public static List taskTypeToString(List adTaskTypes) { + return adTaskTypes.stream().map(type -> type.name()).collect(Collectors.toList()); + } +} diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java index 008f21e4b..aa86fa842 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java @@ -14,7 +14,7 @@ import static org.opensearch.ad.constant.ADCommonName.CUSTOM_RESULT_INDEX_PREFIX; import static org.opensearch.ad.model.AnomalyDetectorType.MULTI_ENTITY; import static org.opensearch.ad.model.AnomalyDetectorType.SINGLE_ENTITY; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; import java.io.IOException; @@ -26,12 +26,12 @@ import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.settings.ADNumericSetting; -import org.opensearch.common.ParsingException; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.unit.TimeValue; import org.opensearch.commons.authuser.User; import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; +import 
org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetector.java-e b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java-e new file mode 100644 index 000000000..674cde43b --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/AnomalyDetector.java-e @@ -0,0 +1,516 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import static org.opensearch.ad.constant.ADCommonName.CUSTOM_RESULT_INDEX_PREFIX; +import static org.opensearch.ad.model.AnomalyDetectorType.MULTI_ENTITY; +import static org.opensearch.ad.model.AnomalyDetectorType.SINGLE_ENTITY; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.settings.ADNumericSetting; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.ParseField; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParseException; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonValue; +import org.opensearch.timeseries.dataprocessor.ImputationOption; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.TimeConfiguration; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.util.ParseUtils; + +/** + * An AnomalyDetector is used to represent anomaly detection model(RCF) related parameters. + * NOTE: If change detector config index mapping, you should change AD task index mapping as well. + * TODO: Will replace detector config mapping in AD task with detector config setting directly \ + * in code rather than config it in anomaly-detection-state.json file. 
+ */ +public class AnomalyDetector extends Config { + + public static final String PARSE_FIELD_NAME = "AnomalyDetector"; + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + AnomalyDetector.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) + ); + public static final String TYPE = "_doc"; + // for bwc, we have to keep this field instead of reusing an interval field in the super class. + // otherwise, we won't be able to recognize "detection_interval" field sent from old implementation. + public static final String DETECTION_INTERVAL_FIELD = "detection_interval"; + public static final String DETECTOR_TYPE_FIELD = "detector_type"; + @Deprecated + public static final String DETECTION_DATE_RANGE_FIELD = "detection_date_range"; + + protected String detectorType; + + // TODO: support backward compatibility, will remove in future + @Deprecated + private DateRange detectionDateRange; + + public static String INVALID_RESULT_INDEX_NAME_SIZE = "Result index name size must contains less than " + + MAX_RESULT_INDEX_NAME_SIZE + + " characters"; + + /** + * Constructor function. + * + * @param detectorId detector identifier + * @param version detector document version + * @param name detector name + * @param description description of detector + * @param timeField time field + * @param indices indices used as detector input + * @param features detector feature attributes + * @param filterQuery detector filter query + * @param detectionInterval detecting interval + * @param windowDelay max delay window for realtime data + * @param shingleSize number of the most recent time intervals to form a shingled data point + * @param uiMetadata metadata used by OpenSearch-Dashboards + * @param schemaVersion anomaly detector index mapping version + * @param lastUpdateTime detector's last update time + * @param categoryFields a list of partition fields + * @param user user to which detector is associated + * @param resultIndex result index + * @param imputationOption interpolation method and optional default values + */ + public AnomalyDetector( + String detectorId, + Long version, + String name, + String description, + String timeField, + List indices, + List features, + QueryBuilder filterQuery, + TimeConfiguration detectionInterval, + TimeConfiguration windowDelay, + Integer shingleSize, + Map uiMetadata, + Integer schemaVersion, + Instant lastUpdateTime, + List categoryFields, + User user, + String resultIndex, + ImputationOption imputationOption + ) { + super( + detectorId, + version, + name, + description, + timeField, + indices, + features, + filterQuery, + windowDelay, + shingleSize, + uiMetadata, + schemaVersion, + lastUpdateTime, + categoryFields, + user, + resultIndex, + detectionInterval, + imputationOption + ); + + checkAndThrowValidationErrors(ValidationAspect.DETECTOR); + + if (detectionInterval == null) { + errorMessage = ADCommonMessages.NULL_DETECTION_INTERVAL; + issueType = ValidationIssueType.DETECTION_INTERVAL; + } else if (((IntervalTimeConfiguration) detectionInterval).getInterval() <= 0) { + errorMessage = ADCommonMessages.INVALID_DETECTION_INTERVAL; + issueType = ValidationIssueType.DETECTION_INTERVAL; + } + + int maxCategoryFields = ADNumericSetting.maxCategoricalFields(); + if (categoryFields != null && categoryFields.size() > maxCategoryFields) { + errorMessage = CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields); + issueType = ValidationIssueType.CATEGORY; + } + + 
checkAndThrowValidationErrors(ValidationAspect.DETECTOR); + + this.detectorType = isHC(categoryFields) ? MULTI_ENTITY.name() : SINGLE_ENTITY.name(); + } + + /* + * For backward compatiblity reason, we cannot use super class + * Config's constructor as we have detectionDateRange and + * detectorType that Config does not have. + */ + public AnomalyDetector(StreamInput input) throws IOException { + id = input.readOptionalString(); + version = input.readOptionalLong(); + name = input.readString(); + description = input.readOptionalString(); + timeField = input.readString(); + indices = input.readStringList(); + featureAttributes = input.readList(Feature::new); + filterQuery = input.readNamedWriteable(QueryBuilder.class); + interval = IntervalTimeConfiguration.readFrom(input); + windowDelay = IntervalTimeConfiguration.readFrom(input); + shingleSize = input.readInt(); + schemaVersion = input.readInt(); + this.categoryFields = input.readOptionalStringList(); + lastUpdateTime = input.readInstant(); + if (input.readBoolean()) { + this.user = new User(input); + } else { + user = null; + } + if (input.readBoolean()) { + detectionDateRange = new DateRange(input); + } else { + detectionDateRange = null; + } + detectorType = input.readOptionalString(); + if (input.readBoolean()) { + this.uiMetadata = input.readMap(); + } else { + this.uiMetadata = null; + } + customResultIndex = input.readOptionalString(); + if (input.readBoolean()) { + this.imputationOption = new ImputationOption(input); + } else { + this.imputationOption = null; + } + this.imputer = createImputer(); + } + + public XContentBuilder toXContent(XContentBuilder builder) throws IOException { + return toXContent(builder, ToXContent.EMPTY_PARAMS); + } + + /* + * For backward compatiblity reason, we cannot use super class + * Config's writeTo as we have detectionDateRange and + * detectorType that Config does not have. 
+ */ + @Override + public void writeTo(StreamOutput output) throws IOException { + output.writeOptionalString(id); + output.writeOptionalLong(version); + output.writeString(name); + output.writeOptionalString(description); + output.writeString(timeField); + output.writeStringCollection(indices); + output.writeList(featureAttributes); + output.writeNamedWriteable(filterQuery); + interval.writeTo(output); + windowDelay.writeTo(output); + output.writeInt(shingleSize); + output.writeInt(schemaVersion); + output.writeOptionalStringCollection(categoryFields); + output.writeInstant(lastUpdateTime); + if (user != null) { + output.writeBoolean(true); // user exists + user.writeTo(output); + } else { + output.writeBoolean(false); // user does not exist + } + if (detectionDateRange != null) { + output.writeBoolean(true); // detectionDateRange exists + detectionDateRange.writeTo(output); + } else { + output.writeBoolean(false); // detectionDateRange does not exist + } + output.writeOptionalString(detectorType); + if (uiMetadata != null) { + output.writeBoolean(true); + output.writeMap(uiMetadata); + } else { + output.writeBoolean(false); + } + output.writeOptionalString(customResultIndex); + if (imputationOption != null) { + output.writeBoolean(true); + imputationOption.writeTo(output); + } else { + output.writeBoolean(false); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + xContentBuilder = super.toXContent(xContentBuilder, params); + xContentBuilder.field(DETECTION_INTERVAL_FIELD, interval); + + if (detectorType != null) { + xContentBuilder.field(DETECTOR_TYPE_FIELD, detectorType); + } + if (detectionDateRange != null) { + xContentBuilder.field(DETECTION_DATE_RANGE_FIELD, detectionDateRange); + } + + return xContentBuilder.endObject(); + } + + /** + * Parse raw json content into anomaly detector instance. + * + * @param parser json based content parser + * @return anomaly detector instance + * @throws IOException IOException if content can't be parsed correctly + */ + public static AnomalyDetector parse(XContentParser parser) throws IOException { + return parse(parser, null); + } + + public static AnomalyDetector parse(XContentParser parser, String detectorId) throws IOException { + return parse(parser, detectorId, null); + } + + /** + * Parse raw json content and given detector id into anomaly detector instance. + * + * @param parser json based content parser + * @param detectorId detector id + * @param version detector document version + * @return anomaly detector instance + * @throws IOException IOException if content can't be parsed correctly + */ + public static AnomalyDetector parse(XContentParser parser, String detectorId, Long version) throws IOException { + return parse(parser, detectorId, version, null, null); + } + + /** + * Parse raw json content and given detector id into anomaly detector instance. 
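As a quick orientation, a minimal sketch of how these parse(...) overloads are typically driven from raw JSON; the registry, deprecation handler, JSON string, and detector id below are placeholders for illustration, not code from this change:

    // Illustrative only: xContentRegistry is a NamedXContentRegistry and detectorJson a raw JSON string.
    try (XContentParser parser = XContentType.JSON.xContent()
            .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, detectorJson)) {
        parser.nextToken(); // position on START_OBJECT, which parse(...) asserts via ensureExpectedToken
        AnomalyDetector detector = AnomalyDetector.parse(parser, "hypothetical-detector-id");
    }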
+ * + * @param parser json based content parser + * @param detectorId detector id + * @param version detector document version + * @param defaultDetectionInterval default detection interval + * @param defaultDetectionWindowDelay default detection window delay + * @return anomaly detector instance + * @throws IOException IOException if content can't be parsed correctly + */ + public static AnomalyDetector parse( + XContentParser parser, + String detectorId, + Long version, + TimeValue defaultDetectionInterval, + TimeValue defaultDetectionWindowDelay + ) throws IOException { + String name = null; + String description = ""; + String timeField = null; + List indices = new ArrayList(); + QueryBuilder filterQuery = QueryBuilders.matchAllQuery(); + TimeConfiguration detectionInterval = defaultDetectionInterval == null + ? null + : new IntervalTimeConfiguration(defaultDetectionInterval.getMinutes(), ChronoUnit.MINUTES); + TimeConfiguration windowDelay = defaultDetectionWindowDelay == null + ? null + : new IntervalTimeConfiguration(defaultDetectionWindowDelay.getSeconds(), ChronoUnit.SECONDS); + Integer shingleSize = null; + List features = new ArrayList<>(); + Integer schemaVersion = CommonValue.NO_SCHEMA_VERSION; + Map uiMetadata = null; + Instant lastUpdateTime = null; + User user = null; + DateRange detectionDateRange = null; + String resultIndex = null; + + List categoryField = null; + ImputationOption imputationOption = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case NAME_FIELD: + name = parser.text(); + break; + case DESCRIPTION_FIELD: + description = parser.text(); + break; + case TIMEFIELD_FIELD: + timeField = parser.text(); + break; + case INDICES_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + indices.add(parser.text()); + } + break; + case UI_METADATA_FIELD: + uiMetadata = parser.map(); + break; + case org.opensearch.timeseries.constant.CommonName.SCHEMA_VERSION_FIELD: + schemaVersion = parser.intValue(); + break; + case FILTER_QUERY_FIELD: + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + try { + filterQuery = parseInnerQueryBuilder(parser); + } catch (ParsingException | XContentParseException e) { + throw new ValidationException( + "Custom query error in data filter: " + e.getMessage(), + ValidationIssueType.FILTER_QUERY, + ValidationAspect.DETECTOR + ); + } catch (IllegalArgumentException e) { + if (!e.getMessage().contains("empty clause")) { + throw e; + } + } + break; + case DETECTION_INTERVAL_FIELD: + try { + detectionInterval = TimeConfiguration.parse(parser); + } catch (Exception e) { + if (e instanceof IllegalArgumentException && e.getMessage().contains(CommonMessages.NEGATIVE_TIME_CONFIGURATION)) { + throw new ValidationException( + "Detection interval must be a positive integer", + ValidationIssueType.DETECTION_INTERVAL, + ValidationAspect.DETECTOR + ); + } + throw e; + } + break; + case FEATURE_ATTRIBUTES_FIELD: + try { + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + features.add(Feature.parse(parser)); + } + } catch (Exception e) { + if (e instanceof ParsingException || e instanceof XContentParseException) { 
+ throw new ValidationException( + "Custom query error: " + e.getMessage(), + ValidationIssueType.FEATURE_ATTRIBUTES, + ValidationAspect.DETECTOR + ); + } + throw e; + } + break; + case WINDOW_DELAY_FIELD: + try { + windowDelay = TimeConfiguration.parse(parser); + } catch (Exception e) { + if (e instanceof IllegalArgumentException && e.getMessage().contains(CommonMessages.NEGATIVE_TIME_CONFIGURATION)) { + throw new ValidationException( + "Window delay interval must be a positive integer", + ValidationIssueType.WINDOW_DELAY, + ValidationAspect.DETECTOR + ); + } + throw e; + } + break; + case SHINGLE_SIZE_FIELD: + shingleSize = parser.intValue(); + break; + case LAST_UPDATE_TIME_FIELD: + lastUpdateTime = ParseUtils.toInstant(parser); + break; + case CATEGORY_FIELD: + categoryField = (List) parser.list(); + break; + case USER_FIELD: + user = User.parse(parser); + break; + case DETECTION_DATE_RANGE_FIELD: + detectionDateRange = DateRange.parse(parser); + break; + case RESULT_INDEX_FIELD: + resultIndex = parser.text(); + break; + case IMPUTATION_OPTION_FIELD: + imputationOption = ImputationOption.parse(parser); + break; + default: + parser.skipChildren(); + break; + } + } + AnomalyDetector detector = new AnomalyDetector( + detectorId, + version, + name, + description, + timeField, + indices, + features, + filterQuery, + detectionInterval, + windowDelay, + getShingleSize(shingleSize), + uiMetadata, + schemaVersion, + lastUpdateTime, + categoryField, + user, + resultIndex, + imputationOption + ); + detector.setDetectionDateRange(detectionDateRange); + return detector; + } + + public String getDetectorType() { + return detectorType; + } + + public void setDetectionDateRange(DateRange detectionDateRange) { + this.detectionDateRange = detectionDateRange; + } + + public DateRange getDetectionDateRange() { + return detectionDateRange; + } + + @Override + protected ValidationAspect getConfigValidationAspect() { + return ValidationAspect.DETECTOR; + } + + @Override + public String validateCustomResultIndex(String resultIndex) { + if (resultIndex != null && !resultIndex.startsWith(CUSTOM_RESULT_INDEX_PREFIX)) { + return ADCommonMessages.INVALID_RESULT_INDEX_PREFIX; + } + return super.validateCustomResultIndex(resultIndex); + } +} diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetectorExecutionInput.java b/src/main/java/org/opensearch/ad/model/AnomalyDetectorExecutionInput.java index a8af32bb8..b2a45c9bb 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyDetectorExecutionInput.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyDetectorExecutionInput.java @@ -11,7 +11,7 @@ package org.opensearch.ad.model; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.time.Instant; diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetectorExecutionInput.java-e b/src/main/java/org/opensearch/ad/model/AnomalyDetectorExecutionInput.java-e new file mode 100644 index 000000000..b2a45c9bb --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/AnomalyDetectorExecutionInput.java-e @@ -0,0 +1,139 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.base.Objects; + +/** + * Input data needed to trigger anomaly detector. + */ +public class AnomalyDetectorExecutionInput implements ToXContentObject { + + private static final String DETECTOR_ID_FIELD = "detector_id"; + private static final String PERIOD_START_FIELD = "period_start"; + private static final String PERIOD_END_FIELD = "period_end"; + private static final String DETECTOR_FIELD = "detector"; + private Instant periodStart; + private Instant periodEnd; + private String detectorId; + private AnomalyDetector detector; + + public AnomalyDetectorExecutionInput(String detectorId, Instant periodStart, Instant periodEnd, AnomalyDetector detector) { + this.periodStart = periodStart; + this.periodEnd = periodEnd; + this.detectorId = detectorId; + this.detector = detector; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder + .startObject() + .field(DETECTOR_ID_FIELD, detectorId) + .field(PERIOD_START_FIELD, periodStart.toEpochMilli()) + .field(PERIOD_END_FIELD, periodEnd.toEpochMilli()) + .field(DETECTOR_FIELD, detector); + return xContentBuilder.endObject(); + } + + public static AnomalyDetectorExecutionInput parse(XContentParser parser) throws IOException { + return parse(parser, null); + } + + public static AnomalyDetectorExecutionInput parse(XContentParser parser, String adId) throws IOException { + Instant periodStart = null; + Instant periodEnd = null; + AnomalyDetector detector = null; + String detectorId = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case DETECTOR_ID_FIELD: + detectorId = parser.text(); + break; + case PERIOD_START_FIELD: + periodStart = ParseUtils.toInstant(parser); + break; + case PERIOD_END_FIELD: + periodEnd = ParseUtils.toInstant(parser); + break; + case DETECTOR_FIELD: + if (parser.currentToken().equals(XContentParser.Token.START_OBJECT)) { + detector = AnomalyDetector.parse(parser, detectorId); + } + break; + default: + break; + } + } + if (!Strings.isNullOrEmpty(adId)) { + detectorId = adId; + } + return new AnomalyDetectorExecutionInput(detectorId, periodStart, periodEnd, detector); + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + AnomalyDetectorExecutionInput that = (AnomalyDetectorExecutionInput) o; + return Objects.equal(getPeriodStart(), that.getPeriodStart()) + && Objects.equal(getPeriodEnd(), that.getPeriodEnd()) + && Objects.equal(getDetectorId(), that.getDetectorId()) + && Objects.equal(getDetector(), that.getDetector()); + } + + @Generated + @Override + public int hashCode() { + return Objects.hashCode(periodStart, periodEnd, detectorId); + } + + public Instant getPeriodStart() { + return periodStart; + 
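For context, a minimal illustration of assembling this execution input for a preview run; the one-hour window, the detector id, and the already-parsed `detector` instance are assumptions for the example (java.time imports elided):

    Instant periodEnd = Instant.now();
    Instant periodStart = periodEnd.minus(Duration.ofHours(1));   // hypothetical one-hour preview window
    AnomalyDetectorExecutionInput input =
        new AnomalyDetectorExecutionInput("hypothetical-detector-id", periodStart, periodEnd, detector);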
} + + public Instant getPeriodEnd() { + return periodEnd; + } + + public String getDetectorId() { + return detectorId; + } + + public AnomalyDetector getDetector() { + return detector; + } + + public void setDetectorId(String detectorId) { + this.detectorId = detectorId; + } +} diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetectorJob.java b/src/main/java/org/opensearch/ad/model/AnomalyDetectorJob.java index 5ff9b07ef..7ef5ae528 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyDetectorJob.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyDetectorJob.java @@ -12,16 +12,16 @@ package org.opensearch.ad.model; import static org.opensearch.ad.settings.AnomalyDetectorSettings.DEFAULT_AD_JOB_LOC_DURATION_SECONDS; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.time.Instant; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; import org.opensearch.commons.authuser.User; import org.opensearch.core.ParseField; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetectorJob.java-e b/src/main/java/org/opensearch/ad/model/AnomalyDetectorJob.java-e new file mode 100644 index 000000000..2762f2f70 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/AnomalyDetectorJob.java-e @@ -0,0 +1,306 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.DEFAULT_AD_JOB_LOC_DURATION_SECONDS; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.ParseField; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.schedule.CronSchedule; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.jobscheduler.spi.schedule.Schedule; +import org.opensearch.jobscheduler.spi.schedule.ScheduleParser; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.TimeConfiguration; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.base.Objects; + +/** + * Anomaly detector job. 
+ */ +public class AnomalyDetectorJob implements Writeable, ToXContentObject, ScheduledJobParameter { + enum ScheduleType { + CRON, + INTERVAL + } + + public static final String PARSE_FIELD_NAME = "AnomalyDetectorJob"; + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + AnomalyDetectorJob.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) + ); + + public static final String NAME_FIELD = "name"; + public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; + public static final String LOCK_DURATION_SECONDS = "lock_duration_seconds"; + + public static final String SCHEDULE_FIELD = "schedule"; + public static final String WINDOW_DELAY_FIELD = "window_delay"; + public static final String IS_ENABLED_FIELD = "enabled"; + public static final String ENABLED_TIME_FIELD = "enabled_time"; + public static final String DISABLED_TIME_FIELD = "disabled_time"; + public static final String USER_FIELD = "user"; + private static final String RESULT_INDEX_FIELD = "result_index"; + + private final String name; + private final Schedule schedule; + private final TimeConfiguration windowDelay; + private final Boolean isEnabled; + private final Instant enabledTime; + private final Instant disabledTime; + private final Instant lastUpdateTime; + private final Long lockDurationSeconds; + private final User user; + private String resultIndex; + + public AnomalyDetectorJob( + String name, + Schedule schedule, + TimeConfiguration windowDelay, + Boolean isEnabled, + Instant enabledTime, + Instant disabledTime, + Instant lastUpdateTime, + Long lockDurationSeconds, + User user, + String resultIndex + ) { + this.name = name; + this.schedule = schedule; + this.windowDelay = windowDelay; + this.isEnabled = isEnabled; + this.enabledTime = enabledTime; + this.disabledTime = disabledTime; + this.lastUpdateTime = lastUpdateTime; + this.lockDurationSeconds = lockDurationSeconds; + this.user = user; + this.resultIndex = resultIndex; + } + + public AnomalyDetectorJob(StreamInput input) throws IOException { + name = input.readString(); + if (input.readEnum(AnomalyDetectorJob.ScheduleType.class) == ScheduleType.CRON) { + schedule = new CronSchedule(input); + } else { + schedule = new IntervalSchedule(input); + } + windowDelay = IntervalTimeConfiguration.readFrom(input); + isEnabled = input.readBoolean(); + enabledTime = input.readInstant(); + disabledTime = input.readInstant(); + lastUpdateTime = input.readInstant(); + lockDurationSeconds = input.readLong(); + if (input.readBoolean()) { + user = new User(input); + } else { + user = null; + } + resultIndex = input.readOptionalString(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder + .startObject() + .field(NAME_FIELD, name) + .field(SCHEDULE_FIELD, schedule) + .field(WINDOW_DELAY_FIELD, windowDelay) + .field(IS_ENABLED_FIELD, isEnabled) + .field(ENABLED_TIME_FIELD, enabledTime.toEpochMilli()) + .field(LAST_UPDATE_TIME_FIELD, lastUpdateTime.toEpochMilli()) + .field(LOCK_DURATION_SECONDS, lockDurationSeconds); + if (disabledTime != null) { + xContentBuilder.field(DISABLED_TIME_FIELD, disabledTime.toEpochMilli()); + } + if (user != null) { + xContentBuilder.field(USER_FIELD, user); + } + if (resultIndex != null) { + xContentBuilder.field(RESULT_INDEX_FIELD, resultIndex); + } + return xContentBuilder.endObject(); + } + + @Override + public void writeTo(StreamOutput output) throws IOException { + 
output.writeString(name); + if (schedule instanceof CronSchedule) { + output.writeEnum(ScheduleType.CRON); + } else { + output.writeEnum(ScheduleType.INTERVAL); + } + schedule.writeTo(output); + windowDelay.writeTo(output); + output.writeBoolean(isEnabled); + output.writeInstant(enabledTime); + output.writeInstant(disabledTime); + output.writeInstant(lastUpdateTime); + output.writeLong(lockDurationSeconds); + if (user != null) { + output.writeBoolean(true); // user exists + user.writeTo(output); + } else { + output.writeBoolean(false); // user does not exist + } + output.writeOptionalString(resultIndex); + } + + public static AnomalyDetectorJob parse(XContentParser parser) throws IOException { + String name = null; + Schedule schedule = null; + TimeConfiguration windowDelay = null; + // we cannot set it to null as isEnabled() would do the unboxing and results in null pointer exception + Boolean isEnabled = Boolean.FALSE; + Instant enabledTime = null; + Instant disabledTime = null; + Instant lastUpdateTime = null; + Long lockDurationSeconds = DEFAULT_AD_JOB_LOC_DURATION_SECONDS; + User user = null; + String resultIndex = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case NAME_FIELD: + name = parser.text(); + break; + case SCHEDULE_FIELD: + schedule = ScheduleParser.parse(parser); + break; + case WINDOW_DELAY_FIELD: + windowDelay = TimeConfiguration.parse(parser); + break; + case IS_ENABLED_FIELD: + isEnabled = parser.booleanValue(); + break; + case ENABLED_TIME_FIELD: + enabledTime = ParseUtils.toInstant(parser); + break; + case DISABLED_TIME_FIELD: + disabledTime = ParseUtils.toInstant(parser); + break; + case LAST_UPDATE_TIME_FIELD: + lastUpdateTime = ParseUtils.toInstant(parser); + break; + case LOCK_DURATION_SECONDS: + lockDurationSeconds = parser.longValue(); + break; + case USER_FIELD: + user = User.parse(parser); + break; + case RESULT_INDEX_FIELD: + resultIndex = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + return new AnomalyDetectorJob( + name, + schedule, + windowDelay, + isEnabled, + enabledTime, + disabledTime, + lastUpdateTime, + lockDurationSeconds, + user, + resultIndex + ); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + AnomalyDetectorJob that = (AnomalyDetectorJob) o; + return Objects.equal(getName(), that.getName()) + && Objects.equal(getSchedule(), that.getSchedule()) + && Objects.equal(isEnabled(), that.isEnabled()) + && Objects.equal(getEnabledTime(), that.getEnabledTime()) + && Objects.equal(getDisabledTime(), that.getDisabledTime()) + && Objects.equal(getLastUpdateTime(), that.getLastUpdateTime()) + && Objects.equal(getLockDurationSeconds(), that.getLockDurationSeconds()) + && Objects.equal(getCustomResultIndex(), that.getCustomResultIndex()); + } + + @Override + public int hashCode() { + return Objects.hashCode(name, schedule, isEnabled, enabledTime, lastUpdateTime); + } + + @Override + public String getName() { + return name; + } + + @Override + public Schedule getSchedule() { + return schedule; + } + + public TimeConfiguration getWindowDelay() { + return windowDelay; + } + + @Override + public boolean isEnabled() { + return isEnabled; + } + + @Override + public Instant getEnabledTime() { + return enabledTime; + } + + public 
Instant getDisabledTime() { + return disabledTime; + } + + @Override + public Instant getLastUpdateTime() { + return lastUpdateTime; + } + + @Override + public Long getLockDurationSeconds() { + return lockDurationSeconds; + } + + public User getUser() { + return user; + } + + public String getCustomResultIndex() { + return resultIndex; + } +} diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetectorType.java-e b/src/main/java/org/opensearch/ad/model/AnomalyDetectorType.java-e new file mode 100644 index 000000000..dee395c0b --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/AnomalyDetectorType.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +public enum AnomalyDetectorType { + @Deprecated + REALTIME_SINGLE_ENTITY, + @Deprecated + REALTIME_MULTI_ENTITY, + @Deprecated + HISTORICAL_SINGLE_ENTITY, + @Deprecated + HISTORICAL_MULTI_ENTITY, + + SINGLE_ENTITY, + MULTI_ENTITY, +} diff --git a/src/main/java/org/opensearch/ad/model/AnomalyResult.java b/src/main/java/org/opensearch/ad/model/AnomalyResult.java index f8222651b..1248b5489 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyResult.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyResult.java @@ -12,7 +12,7 @@ package org.opensearch.ad.model; import static org.opensearch.ad.constant.ADCommonName.DUMMY_DETECTOR_ID; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.time.Instant; @@ -26,10 +26,10 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.ad.ml.ThresholdingResult; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.commons.authuser.User; import org.opensearch.core.ParseField; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; diff --git a/src/main/java/org/opensearch/ad/model/AnomalyResult.java-e b/src/main/java/org/opensearch/ad/model/AnomalyResult.java-e new file mode 100644 index 000000000..baea173af --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/AnomalyResult.java-e @@ -0,0 +1,827 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.model; + +import static org.opensearch.ad.constant.ADCommonName.DUMMY_DETECTOR_ID; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import org.apache.commons.lang.builder.ToStringBuilder; +import org.apache.commons.lang3.StringUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.ParseField; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.constant.CommonValue; +import org.opensearch.timeseries.model.DataByFeatureId; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.base.Objects; + +/** + * Include result returned from RCF model and feature data. + */ +public class AnomalyResult extends IndexableResult { + private static final Logger LOG = LogManager.getLogger(ThresholdingResult.class); + public static final String PARSE_FIELD_NAME = "AnomalyResult"; + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + AnomalyResult.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) + ); + + public static final String DETECTOR_ID_FIELD = "detector_id"; + public static final String ANOMALY_SCORE_FIELD = "anomaly_score"; + public static final String ANOMALY_GRADE_FIELD = "anomaly_grade"; + public static final String APPROX_ANOMALY_START_FIELD = "approx_anomaly_start_time"; + public static final String RELEVANT_ATTRIBUTION_FIELD = "relevant_attribution"; + public static final String PAST_VALUES_FIELD = "past_values"; + public static final String EXPECTED_VALUES_FIELD = "expected_values"; + public static final String THRESHOLD_FIELD = "threshold"; + // unused currently. added since odfe 1.4 + public static final String IS_ANOMALY_FIELD = "is_anomaly"; + + private final Double anomalyScore; + private final Double anomalyGrade; + + /** + * the approximate time of current anomaly. We might detect anomaly late. This field + * is the approximate anomaly time. I called it approximate because rcf may + * not receive continuous data. To make it precise, I have to query previous + * anomaly results and find the what timestamp correspond to a few data points + * back. Instead, RCF returns the index of anomaly relative to current timestamp. + * So approAnomalyStartTime is current time + interval * relativeIndex + * Note {@code relativeIndex <= 0}. If the shingle size is 4, for example shingle is + * [0, 0, 1, 0], and this shingle is detected as anomaly, and actually the + * anomaly is caused by the third item "1", then the relativeIndex will be + * -1. 
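To make the relationship above concrete, a small self-contained sketch of the same arithmetic that fromRawTRCFResult applies further down; the interval and timestamps are made up for illustration:

    import java.time.Instant;

    public class ApproxAnomalyStartExample {
        public static void main(String[] args) {
            long intervalMillis = 5 * 60 * 1000L;                 // assumed 5-minute detector interval
            Instant dataStartTime = Instant.parse("2023-01-01T00:20:00Z");
            int relativeIndex = -1;                               // RCF reports the anomaly began one step back
            Instant approxStart = Instant
                .ofEpochMilli(dataStartTime.toEpochMilli() + relativeIndex * intervalMillis);
            System.out.println(approxStart);                      // 2023-01-01T00:15:00Z, one interval earlier
        }
    }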
+ */ + private final Instant approxAnomalyStartTime; + + // a flattened version denoting the basic contribution of each input variable + private final List relevantAttribution; + + /* + pastValues is related to relativeIndex, startOfAnomaly and anomaly grade. + So if we detect anomaly late, we get the baseDimension values from the past (current is 0). + That is, we look back relativeIndex * baseDimensions. + + For example, current shingle is + "currentValues": [ + 6819.0, + 2375.3333333333335, + 0.0, + 49882.0, + 92070.0, + 5084.0, + 2072.809523809524, + 0.0, + 43529.0, + 91169.0, + 8129.0, + 2582.892857142857, + 12.0, + 54241.0, + 84596.0, + 11174.0, + 3092.9761904761904, + 24.0, + 64952.0, + 78024.0, + 14220.0, + 3603.059523809524, + 37.0, + 75664.0, + 71451.0, + 17265.0, + 4113.142857142857, + 49.0, + 86376.0, + 64878.0, + 16478.0, + 3761.4166666666665, + 37.0, + 78990.0, + 70057.0, + 15691.0, + 3409.690476190476, + 24.0, + 71604.0, + 75236.0 + ], + Since rcf returns relativeIndex is -2, we look back baseDimension * 2 and get the pastValues: + "pastValues": [ + 17265.0, + 4113.142857142857, + 49.0, + 86376.0, + 64878.0 + ], + + So pastValues is null when relativeIndex is 0 or startOfAnomaly is true + or the current shingle is not an anomaly. + + In the UX, if pastValues value is null, we can just show attribution/expected + value and it is implicit this is due to current input; if pastValues is not + null, it means the the attribution/expected values are from an old value + (e.g., 2 steps ago with data [1,2,3]) and we can add a text to explain that. + */ + private final List pastValues; + + /* + * The expected value is only calculated for anomalous detection intervals, + * and will generate expected value for each feature if detector has multiple + * features. + * Currently we expect one set of expected values. In the future, we + * might give different expected values with differently likelihood. So + * the two-dimensional array allows us to future-proof our applications. + * Also, expected values correspond to pastValues if present or current input + * point otherwise. If pastValues is present, we can add a text on UX to explain + * we found an anomaly from the past. 
+ Example: + "expected_value": [{ + "likelihood": 0.8, + "value_list": [{ + "feature_id": "blah", + "value": 1 + }, + { + "feature_id": "blah2", + "value": 1 + } + ] + }]*/ + private final List expectedValuesList; + + // rcf score threshold at the time of writing a result + private final Double threshold; + protected final Double confidence; + + // used when indexing exception or error or an empty result + public AnomalyResult( + String detectorId, + String taskId, + List featureData, + Instant dataStartTime, + Instant dataEndTime, + Instant executionStartTime, + Instant executionEndTime, + String error, + Optional entity, + User user, + Integer schemaVersion, + String modelId + ) { + this( + detectorId, + taskId, + Double.NaN, + Double.NaN, + Double.NaN, + featureData, + dataStartTime, + dataEndTime, + executionStartTime, + executionEndTime, + error, + entity, + user, + schemaVersion, + modelId, + null, + null, + null, + null, + null + ); + } + + public AnomalyResult( + String configId, + String taskId, + Double anomalyScore, + Double anomalyGrade, + Double confidence, + List featureData, + Instant dataStartTime, + Instant dataEndTime, + Instant executionStartTime, + Instant executionEndTime, + String error, + Optional entity, + User user, + Integer schemaVersion, + String modelId, + Instant approxAnomalyStartTime, + List relevantAttribution, + List pastValues, + List expectedValuesList, + Double threshold + ) { + super( + configId, + featureData, + dataStartTime, + dataEndTime, + executionStartTime, + executionEndTime, + error, + entity, + user, + schemaVersion, + modelId, + taskId + ); + this.confidence = confidence; + this.anomalyScore = anomalyScore; + this.anomalyGrade = anomalyGrade; + this.approxAnomalyStartTime = approxAnomalyStartTime; + this.relevantAttribution = relevantAttribution; + this.pastValues = pastValues; + this.expectedValuesList = expectedValuesList; + this.threshold = threshold; + } + + /** + * Factory method that converts raw rcf results to an instance of AnomalyResult + * @param detectorId Detector Id + * @param intervalMillis Detector interval + * @param taskId Task Id + * @param rcfScore RCF score + * @param grade anomaly grade + * @param confidence data confidence + * @param featureData Feature data + * @param dataStartTime Data start time + * @param dataEndTime Data end time + * @param executionStartTime Execution start time + * @param executionEndTime Execution end time + * @param error Error + * @param entity Entity accessor + * @param user the user who created a detector + * @param schemaVersion Result schema version + * @param modelId Model Id + * @param relevantAttribution Attribution of the anomaly + * @param relativeIndex The index of anomaly point relative to current point. 
+ * @param pastValues The input that caused anomaly if we detector anomaly late + * @param expectedValuesList Expected values + * @param likelihoodOfValues Likelihood of the expected values + * @param threshold Current threshold + * @return the converted AnomalyResult instance + */ + public static AnomalyResult fromRawTRCFResult( + String detectorId, + long intervalMillis, + String taskId, + Double rcfScore, + Double grade, + Double confidence, + List featureData, + Instant dataStartTime, + Instant dataEndTime, + Instant executionStartTime, + Instant executionEndTime, + String error, + Optional entity, + User user, + Integer schemaVersion, + String modelId, + double[] relevantAttribution, + Integer relativeIndex, + double[] pastValues, + double[][] expectedValuesList, + double[] likelihoodOfValues, + Double threshold + ) { + List convertedRelevantAttribution = null; + List convertedPastValuesList = null; + List convertedExpectedValues = null; + + if (grade > 0) { + int featureSize = featureData.size(); + if (relevantAttribution != null) { + if (relevantAttribution.length == featureSize) { + convertedRelevantAttribution = new ArrayList<>(featureSize); + for (int j = 0; j < featureSize; j++) { + convertedRelevantAttribution.add(new DataByFeatureId(featureData.get(j).getFeatureId(), relevantAttribution[j])); + } + } else { + LOG + .error( + new ParameterizedMessage( + "Attribution array size does not match. Expected [{}] but got [{}]", + featureSize, + relevantAttribution.length + ) + ); + } + } + + if (pastValues != null) { + if (pastValues.length == featureSize) { + convertedPastValuesList = new ArrayList<>(featureSize); + for (int j = 0; j < featureSize; j++) { + convertedPastValuesList.add(new DataByFeatureId(featureData.get(j).getFeatureId(), pastValues[j])); + } + } else { + LOG + .error( + new ParameterizedMessage( + "Past value array size does not match. Expected [{}] but got [{}]", + featureSize, + pastValues.length + ) + ); + } + } + + if (expectedValuesList != null && expectedValuesList.length > 0) { + int numberOfExpectedLists = expectedValuesList.length; + int numberOfExpectedVals = expectedValuesList[0].length; + if (numberOfExpectedVals == featureSize && likelihoodOfValues.length == numberOfExpectedLists) { + convertedExpectedValues = new ArrayList<>(numberOfExpectedLists); + for (int j = 0; j < numberOfExpectedLists; j++) { + List valueList = new ArrayList<>(featureSize); + for (int k = 0; k < featureSize; k++) { + valueList.add(new DataByFeatureId(featureData.get(k).getFeatureId(), expectedValuesList[j][k])); + } + convertedExpectedValues.add(new ExpectedValueList(likelihoodOfValues[j], valueList)); + } + } else if (numberOfExpectedVals != featureSize) { + LOG + .error( + new ParameterizedMessage( + "expected value array mismatch. Expected [{}] actual [{}].", + featureSize, + numberOfExpectedVals + ) + ); + } else { + LOG + .error( + new ParameterizedMessage( + "likelihood and expected array mismatch: Likelihood [{}] expected value [{}].", + likelihoodOfValues.length, + numberOfExpectedLists + ) + ); + } + } + } + + return new AnomalyResult( + detectorId, + taskId, + rcfScore, + grade, + confidence, + featureData, + dataStartTime, + dataEndTime, + executionStartTime, + executionEndTime, + error, + entity, + user, + schemaVersion, + modelId, + (relativeIndex == null || dataStartTime == null) + ? 
null + : Instant.ofEpochMilli(dataStartTime.toEpochMilli() + relativeIndex * intervalMillis), + convertedRelevantAttribution, + convertedPastValuesList, + convertedExpectedValues, + threshold + ); + } + + public AnomalyResult(StreamInput input) throws IOException { + super(input); + this.confidence = input.readDouble(); + this.anomalyScore = input.readDouble(); + this.anomalyGrade = input.readDouble(); + // if anomaly is caused by current input, we don't show approximate time + this.approxAnomalyStartTime = input.readOptionalInstant(); + + int attributeNumber = input.readVInt(); + if (attributeNumber <= 0) { + this.relevantAttribution = null; + } else { + this.relevantAttribution = new ArrayList<>(attributeNumber); + for (int i = 0; i < attributeNumber; i++) { + relevantAttribution.add(new DataByFeatureId(input)); + } + } + + int pastValueNumber = input.readVInt(); + if (pastValueNumber <= 0) { + this.pastValues = null; + } else { + this.pastValues = new ArrayList<>(pastValueNumber); + for (int i = 0; i < pastValueNumber; i++) { + pastValues.add(new DataByFeatureId(input)); + } + } + + int expectedValuesNumber = input.readVInt(); + if (expectedValuesNumber <= 0) { + this.expectedValuesList = null; + } else { + this.expectedValuesList = new ArrayList<>(); + for (int i = 0; i < expectedValuesNumber; i++) { + expectedValuesList.add(new ExpectedValueList(input)); + } + } + + this.threshold = input.readOptionalDouble(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder + .startObject() + .field(DETECTOR_ID_FIELD, configId) + .field(CommonName.SCHEMA_VERSION_FIELD, schemaVersion); + // In normal AD result, we always pass data start/end times. In custom result index, + // we need to write/delete a dummy AD result to verify if user has write permission + // to the custom result index. Just pass in null start/end time for this dummy anomaly + // result to make sure it won't be queried by mistake. 
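A rough sketch of the write-then-delete permission probe the comment above refers to; the method name, wiring, and error handling are assumptions for illustration (the real check lives elsewhere in the plugin), and imports are elided:

    // Hypothetical sketch: index a dummy result into the custom result index, then delete it again.
    static void probeCustomResultIndexWriteAccess(Client client, String customResultIndex) throws IOException {
        AnomalyResult dummy = AnomalyResult.getDummyResult();              // dummy result with null start/end times
        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
            IndexRequest probe = new IndexRequest(customResultIndex)
                .source(dummy.toXContent(builder, ToXContent.EMPTY_PARAMS));
            client.index(probe, ActionListener.wrap(
                response -> client.delete(                                 // clean up the probe document
                    new DeleteRequest(customResultIndex, response.getId()),
                    ActionListener.wrap(d -> {}, e -> {})),
                e -> { /* write rejected: treat as missing permission on the custom result index */ }));
        }
    }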
+ if (dataStartTime != null) { + xContentBuilder.field(CommonName.DATA_START_TIME_FIELD, dataStartTime.toEpochMilli()); + } + if (dataEndTime != null) { + xContentBuilder.field(CommonName.DATA_END_TIME_FIELD, dataEndTime.toEpochMilli()); + } + if (featureData != null) { + // can be null during preview + xContentBuilder.field(CommonName.FEATURE_DATA_FIELD, featureData.toArray()); + } + if (executionStartTime != null) { + // can be null during preview + xContentBuilder.field(CommonName.EXECUTION_START_TIME_FIELD, executionStartTime.toEpochMilli()); + } + if (executionEndTime != null) { + // can be null during preview + xContentBuilder.field(CommonName.EXECUTION_END_TIME_FIELD, executionEndTime.toEpochMilli()); + } + if (anomalyScore != null && !anomalyScore.isNaN()) { + xContentBuilder.field(ANOMALY_SCORE_FIELD, anomalyScore); + } + if (anomalyGrade != null && !anomalyGrade.isNaN()) { + xContentBuilder.field(ANOMALY_GRADE_FIELD, anomalyGrade); + } + if (confidence != null && !confidence.isNaN()) { + xContentBuilder.field(CommonName.CONFIDENCE_FIELD, confidence); + } + if (error != null) { + xContentBuilder.field(CommonName.ERROR_FIELD, error); + } + if (optionalEntity.isPresent()) { + xContentBuilder.field(CommonName.ENTITY_FIELD, optionalEntity.get()); + } + if (user != null) { + xContentBuilder.field(CommonName.USER_FIELD, user); + } + if (taskId != null) { + xContentBuilder.field(CommonName.TASK_ID_FIELD, taskId); + } + if (modelId != null) { + xContentBuilder.field(CommonName.MODEL_ID_FIELD, modelId); + } + + // output extra fields such as attribution and expected only when this is an anomaly + if (anomalyGrade != null && anomalyGrade > 0) { + if (approxAnomalyStartTime != null) { + xContentBuilder.field(APPROX_ANOMALY_START_FIELD, approxAnomalyStartTime.toEpochMilli()); + } + if (relevantAttribution != null) { + xContentBuilder.array(RELEVANT_ATTRIBUTION_FIELD, relevantAttribution.toArray()); + } + if (pastValues != null) { + xContentBuilder.array(PAST_VALUES_FIELD, pastValues.toArray()); + } + + if (expectedValuesList != null) { + xContentBuilder.array(EXPECTED_VALUES_FIELD, expectedValuesList.toArray()); + } + } + + if (threshold != null && !threshold.isNaN()) { + xContentBuilder.field(THRESHOLD_FIELD, threshold); + } + return xContentBuilder.endObject(); + } + + public static AnomalyResult parse(XContentParser parser) throws IOException { + String detectorId = null; + Double anomalyScore = null; + Double anomalyGrade = null; + Double confidence = null; + List featureData = new ArrayList<>(); + Instant dataStartTime = null; + Instant dataEndTime = null; + Instant executionStartTime = null; + Instant executionEndTime = null; + String error = null; + Entity entity = null; + User user = null; + Integer schemaVersion = CommonValue.NO_SCHEMA_VERSION; + String taskId = null; + String modelId = null; + Instant approAnomalyStartTime = null; + List relavantAttribution = new ArrayList<>(); + List pastValues = new ArrayList<>(); + List expectedValues = new ArrayList<>(); + Double threshold = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case DETECTOR_ID_FIELD: + detectorId = parser.text(); + break; + case ANOMALY_SCORE_FIELD: + anomalyScore = parser.doubleValue(); + break; + case ANOMALY_GRADE_FIELD: + anomalyGrade = parser.doubleValue(); + break; + case CommonName.CONFIDENCE_FIELD: + 
confidence = parser.doubleValue(); + break; + case CommonName.FEATURE_DATA_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + featureData.add(FeatureData.parse(parser)); + } + break; + case CommonName.DATA_START_TIME_FIELD: + dataStartTime = ParseUtils.toInstant(parser); + break; + case CommonName.DATA_END_TIME_FIELD: + dataEndTime = ParseUtils.toInstant(parser); + break; + case CommonName.EXECUTION_START_TIME_FIELD: + executionStartTime = ParseUtils.toInstant(parser); + break; + case CommonName.EXECUTION_END_TIME_FIELD: + executionEndTime = ParseUtils.toInstant(parser); + break; + case CommonName.ERROR_FIELD: + error = parser.text(); + break; + case CommonName.ENTITY_FIELD: + entity = Entity.parse(parser); + break; + case CommonName.USER_FIELD: + user = User.parse(parser); + break; + case CommonName.SCHEMA_VERSION_FIELD: + schemaVersion = parser.intValue(); + break; + case CommonName.TASK_ID_FIELD: + taskId = parser.text(); + break; + case CommonName.MODEL_ID_FIELD: + modelId = parser.text(); + break; + case APPROX_ANOMALY_START_FIELD: + approAnomalyStartTime = ParseUtils.toInstant(parser); + break; + case RELEVANT_ATTRIBUTION_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + relavantAttribution.add(DataByFeatureId.parse(parser)); + } + break; + case PAST_VALUES_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + pastValues.add(DataByFeatureId.parse(parser)); + } + break; + case EXPECTED_VALUES_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + expectedValues.add(ExpectedValueList.parse(parser)); + } + break; + case THRESHOLD_FIELD: + threshold = parser.doubleValue(); + break; + default: + parser.skipChildren(); + break; + } + } + + return new AnomalyResult( + detectorId, + taskId, + anomalyScore, + anomalyGrade, + confidence, + featureData, + dataStartTime, + dataEndTime, + executionStartTime, + executionEndTime, + error, + Optional.ofNullable(entity), + user, + schemaVersion, + modelId, + approAnomalyStartTime, + relavantAttribution, + pastValues, + expectedValues, + threshold + ); + } + + @Generated + @Override + public boolean equals(Object o) { + if (!super.equals(o)) + return false; + if (getClass() != o.getClass()) + return false; + AnomalyResult that = (AnomalyResult) o; + return Objects.equal(confidence, that.confidence) + && Objects.equal(anomalyScore, that.anomalyScore) + && Objects.equal(anomalyGrade, that.anomalyGrade) + && Objects.equal(approxAnomalyStartTime, that.approxAnomalyStartTime) + && Objects.equal(relevantAttribution, that.relevantAttribution) + && Objects.equal(pastValues, that.pastValues) + && Objects.equal(expectedValuesList, that.expectedValuesList) + && Objects.equal(threshold, that.threshold); + } + + @Generated + @Override + public int hashCode() { + final int prime = 31; + int result = super.hashCode(); + result = prime * result + Objects + .hashCode( + confidence, + anomalyScore, + anomalyGrade, + approxAnomalyStartTime, + relevantAttribution, + pastValues, + expectedValuesList, + threshold + ); + return result; + } + + @Generated + @Override + public String toString() { + return super.toString() + + ", " + + new 
ToStringBuilder(this) + .append("confidence", confidence) + .append("anomalyScore", anomalyScore) + .append("anomalyGrade", anomalyGrade) + .append("approAnomalyStartTime", approxAnomalyStartTime) + .append("relavantAttribution", relevantAttribution) + .append("pastValues", pastValues) + .append("expectedValuesList", StringUtils.join(expectedValuesList, "|")) + .append("threshold", threshold) + .toString(); + } + + public Double getConfidence() { + return confidence; + } + + public String getDetectorId() { + return configId; + } + + public Double getAnomalyScore() { + return anomalyScore; + } + + public Double getAnomalyGrade() { + return anomalyGrade; + } + + public Instant getApproAnomalyStartTime() { + return approxAnomalyStartTime; + } + + public List getRelavantAttribution() { + return relevantAttribution; + } + + public List getPastValues() { + return pastValues; + } + + public List getExpectedValuesList() { + return expectedValuesList; + } + + public Double getThreshold() { + return threshold; + } + + /** + * Anomaly result index consists of overwhelmingly (99.5%) zero-grade non-error documents. + * This function exclude the majority case. + * @return whether the anomaly result is important when the anomaly grade is not 0 + * or error is there. + */ + @Override + public boolean isHighPriority() { + // AnomalyResult.toXContent won't record Double.NaN and thus make it null + return (getAnomalyGrade() != null && getAnomalyGrade() > 0) || getError() != null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeDouble(confidence); + out.writeDouble(anomalyScore); + out.writeDouble(anomalyGrade); + + out.writeOptionalInstant(approxAnomalyStartTime); + + if (relevantAttribution != null) { + out.writeVInt(relevantAttribution.size()); + for (DataByFeatureId attribution : relevantAttribution) { + attribution.writeTo(out); + } + } else { + out.writeVInt(0); + } + + if (pastValues != null) { + out.writeVInt(pastValues.size()); + for (DataByFeatureId value : pastValues) { + value.writeTo(out); + } + } else { + out.writeVInt(0); + } + + if (expectedValuesList != null) { + out.writeVInt(expectedValuesList.size()); + for (ExpectedValueList value : expectedValuesList) { + value.writeTo(out); + } + } else { + out.writeVInt(0); + } + + out.writeOptionalDouble(threshold); + } + + public static AnomalyResult getDummyResult() { + return new AnomalyResult( + DUMMY_DETECTOR_ID, + null, + null, + null, + null, + null, + null, + null, + Optional.empty(), + null, + CommonValue.NO_SCHEMA_VERSION, + null + ); + } +} diff --git a/src/main/java/org/opensearch/ad/model/AnomalyResultBucket.java b/src/main/java/org/opensearch/ad/model/AnomalyResultBucket.java index 8f91f34b5..121d34f6d 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyResultBucket.java +++ b/src/main/java/org/opensearch/ad/model/AnomalyResultBucket.java @@ -15,9 +15,9 @@ import java.util.Map; import org.apache.commons.lang.builder.ToStringBuilder; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation.Bucket; diff --git 
a/src/main/java/org/opensearch/ad/model/AnomalyResultBucket.java-e b/src/main/java/org/opensearch/ad/model/AnomalyResultBucket.java-e new file mode 100644 index 000000000..121d34f6d --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/AnomalyResultBucket.java-e @@ -0,0 +1,118 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.io.IOException; +import java.util.Map; + +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation.Bucket; +import org.opensearch.search.aggregations.metrics.InternalMax; +import org.opensearch.timeseries.annotation.Generated; + +import com.google.common.base.Objects; + +/** + * Represents a single bucket when retrieving top anomaly results for HC detectors + */ +public class AnomalyResultBucket implements ToXContentObject, Writeable { + public static final String BUCKETS_FIELD = "buckets"; + public static final String KEY_FIELD = "key"; + public static final String DOC_COUNT_FIELD = "doc_count"; + public static final String MAX_ANOMALY_GRADE_FIELD = "max_anomaly_grade"; + + private final Map key; + private final int docCount; + private final double maxAnomalyGrade; + + public AnomalyResultBucket(Map key, int docCount, double maxAnomalyGrade) { + this.key = key; + this.docCount = docCount; + this.maxAnomalyGrade = maxAnomalyGrade; + } + + public AnomalyResultBucket(StreamInput input) throws IOException { + this.key = input.readMap(); + this.docCount = input.readInt(); + this.maxAnomalyGrade = input.readDouble(); + } + + public static AnomalyResultBucket createAnomalyResultBucket(Bucket bucket) { + return new AnomalyResultBucket( + bucket.getKey(), + (int) bucket.getDocCount(), + ((InternalMax) bucket.getAggregations().get(MAX_ANOMALY_GRADE_FIELD)).getValue() + ); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder + .startObject() + .field(KEY_FIELD, key) + .field(DOC_COUNT_FIELD, docCount) + .field(MAX_ANOMALY_GRADE_FIELD, maxAnomalyGrade); + return xContentBuilder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(key); + out.writeInt(docCount); + out.writeDouble(maxAnomalyGrade); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + AnomalyResultBucket that = (AnomalyResultBucket) o; + return Objects.equal(getKey(), that.getKey()) + && Objects.equal(getDocCount(), that.getDocCount()) + && Objects.equal(getMaxAnomalyGrade(), that.getMaxAnomalyGrade()); + } + + @Generated + @Override + public int hashCode() { + return Objects.hashCode(getKey(), getDocCount(), getMaxAnomalyGrade()); + } + + @Generated + @Override + public String toString() { + return new ToStringBuilder(this) + .append("key", key) + .append("docCount", docCount) + 
.append("maxAnomalyGrade", maxAnomalyGrade) + .toString(); + } + + public Map getKey() { + return key; + } + + public int getDocCount() { + return docCount; + } + + public double getMaxAnomalyGrade() { + return maxAnomalyGrade; + } +} diff --git a/src/main/java/org/opensearch/ad/model/DetectorInternalState.java b/src/main/java/org/opensearch/ad/model/DetectorInternalState.java index 9b127e1ce..8630a7ac6 100644 --- a/src/main/java/org/opensearch/ad/model/DetectorInternalState.java +++ b/src/main/java/org/opensearch/ad/model/DetectorInternalState.java @@ -11,7 +11,7 @@ package org.opensearch.ad.model; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.time.Instant; diff --git a/src/main/java/org/opensearch/ad/model/DetectorInternalState.java-e b/src/main/java/org/opensearch/ad/model/DetectorInternalState.java-e new file mode 100644 index 000000000..8630a7ac6 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/DetectorInternalState.java-e @@ -0,0 +1,154 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.core.ParseField; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.base.Objects; + +/** + * Include anomaly detector's state + */ +public class DetectorInternalState implements ToXContentObject, Cloneable { + + public static final String PARSE_FIELD_NAME = "DetectorInternalState"; + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + DetectorInternalState.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) + ); + + public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; + public static final String ERROR_FIELD = "error"; + + private Instant lastUpdateTime = null; + private String error = null; + + private DetectorInternalState() {} + + public static class Builder { + private Instant lastUpdateTime = null; + private String error = null; + + public Builder() {} + + public Builder lastUpdateTime(Instant lastUpdateTime) { + this.lastUpdateTime = lastUpdateTime; + return this; + } + + public Builder error(String error) { + this.error = error; + return this; + } + + public DetectorInternalState build() { + DetectorInternalState state = new DetectorInternalState(); + state.lastUpdateTime = this.lastUpdateTime; + state.error = this.error; + + return state; + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + + if (lastUpdateTime != null) { + xContentBuilder.field(LAST_UPDATE_TIME_FIELD, lastUpdateTime.toEpochMilli()); + } + if (error != null) { + xContentBuilder.field(ERROR_FIELD, 
error); + } + return xContentBuilder.endObject(); + } + + public static DetectorInternalState parse(XContentParser parser) throws IOException { + Instant lastUpdateTime = null; + String error = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case LAST_UPDATE_TIME_FIELD: + lastUpdateTime = ParseUtils.toInstant(parser); + break; + case ERROR_FIELD: + error = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + return new DetectorInternalState.Builder().lastUpdateTime(lastUpdateTime).error(error).build(); + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + DetectorInternalState that = (DetectorInternalState) o; + return Objects.equal(getLastUpdateTime(), that.getLastUpdateTime()) && Objects.equal(getError(), that.getError()); + } + + @Generated + @Override + public int hashCode() { + return Objects.hashCode(lastUpdateTime, error); + } + + @Override + public Object clone() { + DetectorInternalState state = null; + try { + state = (DetectorInternalState) super.clone(); + } catch (CloneNotSupportedException e) { + state = new DetectorInternalState.Builder().lastUpdateTime(lastUpdateTime).error(error).build(); + } + return state; + } + + public Instant getLastUpdateTime() { + return lastUpdateTime; + } + + public void setLastUpdateTime(Instant lastUpdateTime) { + this.lastUpdateTime = lastUpdateTime; + } + + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } +} diff --git a/src/main/java/org/opensearch/ad/model/DetectorProfile.java b/src/main/java/org/opensearch/ad/model/DetectorProfile.java index 3bbd2558b..77418552e 100644 --- a/src/main/java/org/opensearch/ad/model/DetectorProfile.java +++ b/src/main/java/org/opensearch/ad/model/DetectorProfile.java @@ -17,9 +17,9 @@ import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.commons.lang.builder.ToStringBuilder; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/model/DetectorProfile.java-e b/src/main/java/org/opensearch/ad/model/DetectorProfile.java-e new file mode 100644 index 000000000..77418552e --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/DetectorProfile.java-e @@ -0,0 +1,465 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
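A minimal usage sketch (not part of the diff) for the DetectorInternalState builder, parse round-trip fields, and clone fallback shown above; the class name and values here are illustrative only.

import java.time.Instant;

import org.opensearch.ad.model.DetectorInternalState;

public class DetectorInternalStateSketch {
    public static void main(String[] args) {
        // Build a state holding the last error seen by the detector and when it was recorded.
        DetectorInternalState state = new DetectorInternalState.Builder()
            .lastUpdateTime(Instant.now())
            .error("Feature query returned no data")   // illustrative error message
            .build();

        // clone() uses Object.clone(), falling back to the builder on CloneNotSupportedException.
        DetectorInternalState copy = (DetectorInternalState) state.clone();
        assert copy.getError().equals(state.getError());
        assert copy.getLastUpdateTime().equals(state.getLastUpdateTime());
    }
}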
+ */ + +package org.opensearch.ad.model; + +import java.io.IOException; + +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +public class DetectorProfile implements Writeable, ToXContentObject, Mergeable { + private DetectorState state; + private String error; + private ModelProfileOnNode[] modelProfile; + private int shingleSize; + private String coordinatingNode; + private long totalSizeInBytes; + private InitProgressProfile initProgress; + private Long totalEntities; + private Long activeEntities; + private ADTaskProfile adTaskProfile; + private long modelCount; + + public XContentBuilder toXContent(XContentBuilder builder) throws IOException { + return toXContent(builder, ToXContent.EMPTY_PARAMS); + } + + public DetectorProfile(StreamInput in) throws IOException { + if (in.readBoolean()) { + this.state = in.readEnum(DetectorState.class); + } + + this.error = in.readOptionalString(); + this.modelProfile = in.readOptionalArray(ModelProfileOnNode::new, ModelProfileOnNode[]::new); + this.shingleSize = in.readOptionalInt(); + this.coordinatingNode = in.readOptionalString(); + this.totalSizeInBytes = in.readOptionalLong(); + this.totalEntities = in.readOptionalLong(); + this.activeEntities = in.readOptionalLong(); + if (in.readBoolean()) { + this.initProgress = new InitProgressProfile(in); + } + if (in.readBoolean()) { + this.adTaskProfile = new ADTaskProfile(in); + } + this.modelCount = in.readVLong(); + } + + private DetectorProfile() {} + + public static class Builder { + private DetectorState state = null; + private String error = null; + private ModelProfileOnNode[] modelProfile = null; + private int shingleSize = -1; + private String coordinatingNode = null; + private long totalSizeInBytes = -1; + private InitProgressProfile initProgress = null; + private Long totalEntities; + private Long activeEntities; + private ADTaskProfile adTaskProfile; + private long modelCount = 0; + + public Builder() {} + + public Builder state(DetectorState state) { + this.state = state; + return this; + } + + public Builder error(String error) { + this.error = error; + return this; + } + + public Builder modelProfile(ModelProfileOnNode[] modelProfile) { + this.modelProfile = modelProfile; + return this; + } + + public Builder modelCount(long modelCount) { + this.modelCount = modelCount; + return this; + } + + public Builder shingleSize(int shingleSize) { + this.shingleSize = shingleSize; + return this; + } + + public Builder coordinatingNode(String coordinatingNode) { + this.coordinatingNode = coordinatingNode; + return this; + } + + public Builder totalSizeInBytes(long totalSizeInBytes) { + this.totalSizeInBytes = totalSizeInBytes; + return this; + } + + public Builder initProgress(InitProgressProfile initProgress) { + this.initProgress = initProgress; + return this; + } + + public Builder totalEntities(Long totalEntities) { + this.totalEntities = totalEntities; + return this; + } + + public Builder activeEntities(Long activeEntities) { + this.activeEntities = activeEntities; + return this; + } + + public Builder 
adTaskProfile(ADTaskProfile adTaskProfile) { + this.adTaskProfile = adTaskProfile; + return this; + } + + public DetectorProfile build() { + DetectorProfile profile = new DetectorProfile(); + profile.state = this.state; + profile.error = this.error; + profile.modelProfile = modelProfile; + profile.modelCount = modelCount; + profile.shingleSize = shingleSize; + profile.coordinatingNode = coordinatingNode; + profile.totalSizeInBytes = totalSizeInBytes; + profile.initProgress = initProgress; + profile.totalEntities = totalEntities; + profile.activeEntities = activeEntities; + profile.adTaskProfile = adTaskProfile; + + return profile; + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (state == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeEnum(state); + } + + out.writeOptionalString(error); + out.writeOptionalArray(modelProfile); + out.writeOptionalInt(shingleSize); + out.writeOptionalString(coordinatingNode); + out.writeOptionalLong(totalSizeInBytes); + out.writeOptionalLong(totalEntities); + out.writeOptionalLong(activeEntities); + if (initProgress == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + initProgress.writeTo(out); + } + if (adTaskProfile == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + adTaskProfile.writeTo(out); + } + out.writeVLong(modelCount); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + + if (state != null) { + xContentBuilder.field(ADCommonName.STATE, state); + } + if (error != null) { + xContentBuilder.field(ADCommonName.ERROR, error); + } + if (modelProfile != null && modelProfile.length > 0) { + xContentBuilder.startArray(ADCommonName.MODELS); + for (ModelProfileOnNode profile : modelProfile) { + profile.toXContent(xContentBuilder, params); + } + xContentBuilder.endArray(); + } + if (shingleSize != -1) { + xContentBuilder.field(ADCommonName.SHINGLE_SIZE, shingleSize); + } + if (coordinatingNode != null && !coordinatingNode.isEmpty()) { + xContentBuilder.field(ADCommonName.COORDINATING_NODE, coordinatingNode); + } + if (totalSizeInBytes != -1) { + xContentBuilder.field(ADCommonName.TOTAL_SIZE_IN_BYTES, totalSizeInBytes); + } + if (initProgress != null) { + xContentBuilder.field(ADCommonName.INIT_PROGRESS, initProgress); + } + if (totalEntities != null) { + xContentBuilder.field(ADCommonName.TOTAL_ENTITIES, totalEntities); + } + if (activeEntities != null) { + xContentBuilder.field(ADCommonName.ACTIVE_ENTITIES, activeEntities); + } + if (adTaskProfile != null) { + xContentBuilder.field(ADCommonName.AD_TASK, adTaskProfile); + } + if (modelCount > 0) { + xContentBuilder.field(ADCommonName.MODEL_COUNT, modelCount); + } + return xContentBuilder.endObject(); + } + + public DetectorState getState() { + return state; + } + + public void setState(DetectorState state) { + this.state = state; + } + + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } + + public ModelProfileOnNode[] getModelProfile() { + return modelProfile; + } + + public void setModelProfile(ModelProfileOnNode[] modelProfile) { + this.modelProfile = modelProfile; + } + + public int getShingleSize() { + return shingleSize; + } + + public void setShingleSize(int shingleSize) { + this.shingleSize = shingleSize; + } + + public String getCoordinatingNode() { + return coordinatingNode; + } + + public void 
setCoordinatingNode(String coordinatingNode) { + this.coordinatingNode = coordinatingNode; + } + + public long getTotalSizeInBytes() { + return totalSizeInBytes; + } + + public void setTotalSizeInBytes(long totalSizeInBytes) { + this.totalSizeInBytes = totalSizeInBytes; + } + + public InitProgressProfile getInitProgress() { + return initProgress; + } + + public void setInitProgress(InitProgressProfile initProgress) { + this.initProgress = initProgress; + } + + public Long getTotalEntities() { + return totalEntities; + } + + public void setTotalEntities(Long totalEntities) { + this.totalEntities = totalEntities; + } + + public Long getActiveEntities() { + return activeEntities; + } + + public void setActiveEntities(Long activeEntities) { + this.activeEntities = activeEntities; + } + + public ADTaskProfile getAdTaskProfile() { + return adTaskProfile; + } + + public void setAdTaskProfile(ADTaskProfile adTaskProfile) { + this.adTaskProfile = adTaskProfile; + } + + public long getModelCount() { + return modelCount; + } + + public void setModelCount(long modelCount) { + this.modelCount = modelCount; + } + + @Override + public void merge(Mergeable other) { + if (this == other || other == null || getClass() != other.getClass()) { + return; + } + DetectorProfile otherProfile = (DetectorProfile) other; + if (otherProfile.getState() != null) { + this.state = otherProfile.getState(); + } + if (otherProfile.getError() != null) { + this.error = otherProfile.getError(); + } + if (otherProfile.getCoordinatingNode() != null) { + this.coordinatingNode = otherProfile.getCoordinatingNode(); + } + if (otherProfile.getShingleSize() != -1) { + this.shingleSize = otherProfile.getShingleSize(); + } + if (otherProfile.getModelProfile() != null) { + this.modelProfile = otherProfile.getModelProfile(); + } + if (otherProfile.getTotalSizeInBytes() != -1) { + this.totalSizeInBytes = otherProfile.getTotalSizeInBytes(); + } + if (otherProfile.getInitProgress() != null) { + this.initProgress = otherProfile.getInitProgress(); + } + if (otherProfile.getTotalEntities() != null) { + this.totalEntities = otherProfile.getTotalEntities(); + } + if (otherProfile.getActiveEntities() != null) { + this.activeEntities = otherProfile.getActiveEntities(); + } + if (otherProfile.getAdTaskProfile() != null) { + this.adTaskProfile = otherProfile.getAdTaskProfile(); + } + if (otherProfile.getModelCount() > 0) { + this.modelCount = otherProfile.getModelCount(); + } + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + if (obj instanceof DetectorProfile) { + DetectorProfile other = (DetectorProfile) obj; + + EqualsBuilder equalsBuilder = new EqualsBuilder(); + if (state != null) { + equalsBuilder.append(state, other.state); + } + if (error != null) { + equalsBuilder.append(error, other.error); + } + if (modelProfile != null && modelProfile.length > 0) { + equalsBuilder.append(modelProfile, other.modelProfile); + } + if (shingleSize != -1) { + equalsBuilder.append(shingleSize, other.shingleSize); + } + if (coordinatingNode != null) { + equalsBuilder.append(coordinatingNode, other.coordinatingNode); + } + if (totalSizeInBytes != -1) { + equalsBuilder.append(totalSizeInBytes, other.totalSizeInBytes); + } + if (initProgress != null) { + equalsBuilder.append(initProgress, other.initProgress); + } + if (totalEntities != null) { + equalsBuilder.append(totalEntities, other.totalEntities); + } + if (activeEntities != null) { + 
equalsBuilder.append(activeEntities, other.activeEntities); + } + if (adTaskProfile != null) { + equalsBuilder.append(adTaskProfile, other.adTaskProfile); + } + if (modelCount > 0) { + equalsBuilder.append(modelCount, other.modelCount); + } + return equalsBuilder.isEquals(); + } + return false; + } + + @Override + public int hashCode() { + return new HashCodeBuilder() + .append(state) + .append(error) + .append(modelProfile) + .append(shingleSize) + .append(coordinatingNode) + .append(totalSizeInBytes) + .append(initProgress) + .append(totalEntities) + .append(activeEntities) + .append(adTaskProfile) + .append(modelCount) + .toHashCode(); + } + + @Override + public String toString() { + ToStringBuilder toStringBuilder = new ToStringBuilder(this); + + if (state != null) { + toStringBuilder.append(ADCommonName.STATE, state); + } + if (error != null) { + toStringBuilder.append(ADCommonName.ERROR, error); + } + if (modelProfile != null && modelProfile.length > 0) { + toStringBuilder.append(modelProfile); + } + if (shingleSize != -1) { + toStringBuilder.append(ADCommonName.SHINGLE_SIZE, shingleSize); + } + if (coordinatingNode != null) { + toStringBuilder.append(ADCommonName.COORDINATING_NODE, coordinatingNode); + } + if (totalSizeInBytes != -1) { + toStringBuilder.append(ADCommonName.TOTAL_SIZE_IN_BYTES, totalSizeInBytes); + } + if (initProgress != null) { + toStringBuilder.append(ADCommonName.INIT_PROGRESS, initProgress); + } + if (totalEntities != null) { + toStringBuilder.append(ADCommonName.TOTAL_ENTITIES, totalEntities); + } + if (activeEntities != null) { + toStringBuilder.append(ADCommonName.ACTIVE_ENTITIES, activeEntities); + } + if (adTaskProfile != null) { + toStringBuilder.append(ADCommonName.AD_TASK, adTaskProfile); + } + if (modelCount > 0) { + toStringBuilder.append(ADCommonName.MODEL_COUNT, modelCount); + } + return toStringBuilder.toString(); + } +} diff --git a/src/main/java/org/opensearch/ad/model/DetectorProfileName.java-e b/src/main/java/org/opensearch/ad/model/DetectorProfileName.java-e new file mode 100644 index 000000000..443066ac8 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/DetectorProfileName.java-e @@ -0,0 +1,79 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
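DetectorProfile.merge() above copies only the fields the other profile actually set (non-null, or different from the -1/0 builder sentinels), so partial profiles gathered from different sources can be combined. A hedged sketch of that behavior; the node id, sizes, and class name are made up for illustration.

import org.opensearch.ad.model.DetectorProfile;
import org.opensearch.ad.model.DetectorState;

public class DetectorProfileMergeSketch {
    public static void main(String[] args) {
        // Partial profile from one source: detector state only.
        DetectorProfile stateOnly = new DetectorProfile.Builder()
            .state(DetectorState.INIT)
            .build();

        // Partial profile from another source: model and coordinating-node information only.
        DetectorProfile modelOnly = new DetectorProfile.Builder()
            .coordinatingNode("node-1")          // illustrative node id
            .totalSizeInBytes(123_456L)
            .modelCount(3)
            .build();

        // merge() keeps INIT (the other profile's state is null) and pulls in the set fields.
        stateOnly.merge(modelOnly);

        assert stateOnly.getState() == DetectorState.INIT;
        assert "node-1".equals(stateOnly.getCoordinatingNode());
        assert stateOnly.getTotalSizeInBytes() == 123_456L;
        assert stateOnly.getModelCount() == 3;
    }
}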
+ */ + +package org.opensearch.ad.model; + +import java.util.Collection; +import java.util.Set; + +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.timeseries.Name; + +public enum DetectorProfileName implements Name { + STATE(ADCommonName.STATE), + ERROR(ADCommonName.ERROR), + COORDINATING_NODE(ADCommonName.COORDINATING_NODE), + SHINGLE_SIZE(ADCommonName.SHINGLE_SIZE), + TOTAL_SIZE_IN_BYTES(ADCommonName.TOTAL_SIZE_IN_BYTES), + MODELS(ADCommonName.MODELS), + INIT_PROGRESS(ADCommonName.INIT_PROGRESS), + TOTAL_ENTITIES(ADCommonName.TOTAL_ENTITIES), + ACTIVE_ENTITIES(ADCommonName.ACTIVE_ENTITIES), + AD_TASK(ADCommonName.AD_TASK); + + private String name; + + DetectorProfileName(String name) { + this.name = name; + } + + /** + * Get profile name + * + * @return name + */ + @Override + public String getName() { + return name; + } + + public static DetectorProfileName getName(String name) { + switch (name) { + case ADCommonName.STATE: + return STATE; + case ADCommonName.ERROR: + return ERROR; + case ADCommonName.COORDINATING_NODE: + return COORDINATING_NODE; + case ADCommonName.SHINGLE_SIZE: + return SHINGLE_SIZE; + case ADCommonName.TOTAL_SIZE_IN_BYTES: + return TOTAL_SIZE_IN_BYTES; + case ADCommonName.MODELS: + return MODELS; + case ADCommonName.INIT_PROGRESS: + return INIT_PROGRESS; + case ADCommonName.TOTAL_ENTITIES: + return TOTAL_ENTITIES; + case ADCommonName.ACTIVE_ENTITIES: + return ACTIVE_ENTITIES; + case ADCommonName.AD_TASK: + return AD_TASK; + default: + throw new IllegalArgumentException(ADCommonMessages.UNSUPPORTED_PROFILE_TYPE); + } + } + + public static Set getNames(Collection names) { + return Name.getNameFromCollection(names, DetectorProfileName::getName); + } +} diff --git a/src/main/java/org/opensearch/ad/model/DetectorState.java-e b/src/main/java/org/opensearch/ad/model/DetectorState.java-e new file mode 100644 index 000000000..a4959417b --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/DetectorState.java-e @@ -0,0 +1,18 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.model; + +public enum DetectorState { + DISABLED, + INIT, + RUNNING +} diff --git a/src/main/java/org/opensearch/ad/model/DetectorValidationIssue.java b/src/main/java/org/opensearch/ad/model/DetectorValidationIssue.java index 41a68263e..48586e7f8 100644 --- a/src/main/java/org/opensearch/ad/model/DetectorValidationIssue.java +++ b/src/main/java/org/opensearch/ad/model/DetectorValidationIssue.java @@ -14,9 +14,9 @@ import java.io.IOException; import java.util.Map; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.model.IntervalTimeConfiguration; diff --git a/src/main/java/org/opensearch/ad/model/DetectorValidationIssue.java-e b/src/main/java/org/opensearch/ad/model/DetectorValidationIssue.java-e new file mode 100644 index 000000000..48586e7f8 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/DetectorValidationIssue.java-e @@ -0,0 +1,154 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; + +import com.google.common.base.Objects; + +/** + * DetectorValidationIssue is a single validation issue found for detector. 
+ * + * For example, if detector's multiple features are using wrong type field or non existing field + * the issue would be in `detector` aspect, not `model`; + * and its type is FEATURE_ATTRIBUTES, because it is related to feature; + * message would be the message from thrown exception; + * subIssues are issues for each feature; + * suggestion is how to fix the issue/subIssues found + */ +public class DetectorValidationIssue implements ToXContentObject, Writeable { + private static final String MESSAGE_FIELD = "message"; + private static final String SUGGESTED_FIELD_NAME = "suggested_value"; + private static final String SUB_ISSUES_FIELD_NAME = "sub_issues"; + + private final ValidationAspect aspect; + private final ValidationIssueType type; + private final String message; + private Map subIssues; + private IntervalTimeConfiguration intervalSuggestion; + + public ValidationAspect getAspect() { + return aspect; + } + + public ValidationIssueType getType() { + return type; + } + + public String getMessage() { + return message; + } + + public Map getSubIssues() { + return subIssues; + } + + public IntervalTimeConfiguration getIntervalSuggestion() { + return intervalSuggestion; + } + + public DetectorValidationIssue( + ValidationAspect aspect, + ValidationIssueType type, + String message, + Map subIssues, + IntervalTimeConfiguration intervalSuggestion + ) { + this.aspect = aspect; + this.type = type; + this.message = message; + this.subIssues = subIssues; + this.intervalSuggestion = intervalSuggestion; + } + + public DetectorValidationIssue(ValidationAspect aspect, ValidationIssueType type, String message) { + this(aspect, type, message, null, null); + } + + public DetectorValidationIssue(StreamInput input) throws IOException { + aspect = input.readEnum(ValidationAspect.class); + type = input.readEnum(ValidationIssueType.class); + message = input.readString(); + if (input.readBoolean()) { + subIssues = input.readMap(StreamInput::readString, StreamInput::readString); + } + if (input.readBoolean()) { + intervalSuggestion = IntervalTimeConfiguration.readFrom(input); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(aspect); + out.writeEnum(type); + out.writeString(message); + if (subIssues != null && !subIssues.isEmpty()) { + out.writeBoolean(true); + out.writeMap(subIssues, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + if (intervalSuggestion != null) { + out.writeBoolean(true); + intervalSuggestion.writeTo(out); + } else { + out.writeBoolean(false); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject().startObject(type.getName()); + xContentBuilder.field(MESSAGE_FIELD, message); + if (subIssues != null) { + XContentBuilder subIssuesBuilder = xContentBuilder.startObject(SUB_ISSUES_FIELD_NAME); + for (Map.Entry entry : subIssues.entrySet()) { + subIssuesBuilder.field(entry.getKey(), entry.getValue()); + } + subIssuesBuilder.endObject(); + } + if (intervalSuggestion != null) { + xContentBuilder.field(SUGGESTED_FIELD_NAME, intervalSuggestion); + } + + return xContentBuilder.endObject().endObject(); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + DetectorValidationIssue anotherIssue = (DetectorValidationIssue) o; + return Objects.equal(getAspect(), anotherIssue.getAspect()) + 
&& Objects.equal(getMessage(), anotherIssue.getMessage()) + && Objects.equal(getSubIssues(), anotherIssue.getSubIssues()) + && Objects.equal(getIntervalSuggestion(), anotherIssue.getIntervalSuggestion()) + && Objects.equal(getType(), anotherIssue.getType()); + } + + @Override + public int hashCode() { + return Objects.hashCode(aspect, message, subIssues, subIssues, type); + } +} diff --git a/src/main/java/org/opensearch/ad/model/EntityAnomalyResult.java-e b/src/main/java/org/opensearch/ad/model/EntityAnomalyResult.java-e new file mode 100644 index 000000000..7eeb02e6c --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/EntityAnomalyResult.java-e @@ -0,0 +1,38 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.util.List; + +public class EntityAnomalyResult implements Mergeable { + + private List anomalyResults; + + public EntityAnomalyResult(List anomalyResults) { + this.anomalyResults = anomalyResults; + } + + public List getAnomalyResults() { + return anomalyResults; + } + + @Override + public void merge(Mergeable other) { + if (this == other || other == null || getClass() != other.getClass()) { + return; + } + EntityAnomalyResult otherEntityAnomalyResult = (EntityAnomalyResult) other; + if (otherEntityAnomalyResult.getAnomalyResults() != null) { + this.anomalyResults.addAll(otherEntityAnomalyResult.getAnomalyResults()); + } + } +} diff --git a/src/main/java/org/opensearch/ad/model/EntityProfile.java b/src/main/java/org/opensearch/ad/model/EntityProfile.java index 2dfd91226..4f2306e96 100644 --- a/src/main/java/org/opensearch/ad/model/EntityProfile.java +++ b/src/main/java/org/opensearch/ad/model/EntityProfile.java @@ -18,9 +18,9 @@ import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.commons.lang.builder.ToStringBuilder; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/model/EntityProfile.java-e b/src/main/java/org/opensearch/ad/model/EntityProfile.java-e new file mode 100644 index 000000000..4f2306e96 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/EntityProfile.java-e @@ -0,0 +1,288 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
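For context, a sketch of how the two DetectorValidationIssue constructors above might be called. FEATURE_ATTRIBUTES appears in the Javadoc, but ValidationAspect.DETECTOR/MODEL, ValidationIssueType.DETECTION_INTERVAL, and the IntervalTimeConfiguration(long, ChronoUnit) constructor are assumptions inferred from the field names used here, not confirmed by this diff.

import java.time.temporal.ChronoUnit;
import java.util.Map;

import org.opensearch.ad.model.DetectorValidationIssue;
import org.opensearch.timeseries.model.IntervalTimeConfiguration;
import org.opensearch.timeseries.model.ValidationAspect;
import org.opensearch.timeseries.model.ValidationIssueType;

public class ValidationIssueSketch {
    public static void main(String[] args) {
        // A feature-level issue: per-feature messages go into sub_issues.
        DetectorValidationIssue featureIssue = new DetectorValidationIssue(
            ValidationAspect.DETECTOR,                     // assumed enum constant
            ValidationIssueType.FEATURE_ATTRIBUTES,
            "Feature has an invalid query",
            Map.of("feature_1", "field does not exist"),   // sub-issue keyed by feature name
            null                                           // no interval suggestion for this type
        );

        // An interval issue: suggested_value carries a recommended detector interval.
        DetectorValidationIssue intervalIssue = new DetectorValidationIssue(
            ValidationAspect.MODEL,                        // assumed enum constant
            ValidationIssueType.DETECTION_INTERVAL,        // assumed enum constant
            "Detector interval is likely too small",
            null,
            new IntervalTimeConfiguration(10, ChronoUnit.MINUTES)  // assumed constructor
        );

        assert featureIssue.getSubIssues() != null;
        assert intervalIssue.getIntervalSuggestion() != null;
    }
}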
+ */ + +package org.opensearch.ad.model; + +import java.io.IOException; +import java.util.Optional; + +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +/** + * Profile output for detector entity. + */ +public class EntityProfile implements Writeable, ToXContent, Mergeable { + // field name in toXContent + public static final String IS_ACTIVE = "is_active"; + public static final String LAST_ACTIVE_TIMESTAMP = "last_active_timestamp"; + public static final String LAST_SAMPLE_TIMESTAMP = "last_sample_timestamp"; + + private Boolean isActive; + private long lastActiveTimestampMs; + private long lastSampleTimestampMs; + private InitProgressProfile initProgress; + private ModelProfileOnNode modelProfile; + private EntityState state; + + public EntityProfile( + Boolean isActive, + long lastActiveTimeStamp, + long lastSampleTimestamp, + InitProgressProfile initProgress, + ModelProfileOnNode modelProfile, + EntityState state + ) { + super(); + this.isActive = isActive; + this.lastActiveTimestampMs = lastActiveTimeStamp; + this.lastSampleTimestampMs = lastSampleTimestamp; + this.initProgress = initProgress; + this.modelProfile = modelProfile; + this.state = state; + } + + public static class Builder { + private Boolean isActive = null; + private long lastActiveTimestampMs = -1L; + private long lastSampleTimestampMs = -1L; + private InitProgressProfile initProgress = null; + private ModelProfileOnNode modelProfile = null; + private EntityState state = EntityState.UNKNOWN; + + public Builder isActive(Boolean isActive) { + this.isActive = isActive; + return this; + } + + public Builder lastActiveTimestampMs(long lastActiveTimestampMs) { + this.lastActiveTimestampMs = lastActiveTimestampMs; + return this; + } + + public Builder lastSampleTimestampMs(long lastSampleTimestampMs) { + this.lastSampleTimestampMs = lastSampleTimestampMs; + return this; + } + + public Builder initProgress(InitProgressProfile initProgress) { + this.initProgress = initProgress; + return this; + } + + public Builder modelProfile(ModelProfileOnNode modelProfile) { + this.modelProfile = modelProfile; + return this; + } + + public Builder state(EntityState state) { + this.state = state; + return this; + } + + public EntityProfile build() { + return new EntityProfile(isActive, lastActiveTimestampMs, lastSampleTimestampMs, initProgress, modelProfile, state); + } + } + + public EntityProfile(StreamInput in) throws IOException { + this.isActive = in.readOptionalBoolean(); + this.lastActiveTimestampMs = in.readLong(); + this.lastSampleTimestampMs = in.readLong(); + if (in.readBoolean()) { + this.initProgress = new InitProgressProfile(in); + } + if (in.readBoolean()) { + this.modelProfile = new ModelProfileOnNode(in); + } + this.state = in.readEnum(EntityState.class); + } + + public Optional getActive() { + return Optional.ofNullable(isActive); + } + + /** + * Return the last active time of an entity's state. + * + * If the entity's state is active in the cache, the value indicates when the cache + * is lastly accessed (get/put). 
If the entity's state is inactive in the cache, + * the value indicates when the cache state is created or when the entity is evicted + * from active entity cache. + * + * @return the last active time of an entity's state + */ + public Long getLastActiveTimestamp() { + return lastActiveTimestampMs; + } + + /** + * + * @return last document's timestamp belonging to an entity + */ + public Long getLastSampleTimestamp() { + return lastSampleTimestampMs; + } + + public InitProgressProfile getInitProgress() { + return initProgress; + } + + public ModelProfileOnNode getModelProfile() { + return modelProfile; + } + + public EntityState getState() { + return state; + } + + public XContentBuilder toXContent(XContentBuilder builder) throws IOException { + return toXContent(builder, ToXContent.EMPTY_PARAMS); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (isActive != null) { + builder.field(IS_ACTIVE, isActive); + } + if (lastActiveTimestampMs > 0) { + builder.field(LAST_ACTIVE_TIMESTAMP, lastActiveTimestampMs); + } + if (lastSampleTimestampMs > 0) { + builder.field(LAST_SAMPLE_TIMESTAMP, lastSampleTimestampMs); + } + if (initProgress != null) { + builder.field(ADCommonName.INIT_PROGRESS, initProgress); + } + if (modelProfile != null) { + builder.field(ADCommonName.MODEL, modelProfile); + } + if (state != null && state != EntityState.UNKNOWN) { + builder.field(ADCommonName.STATE, state); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalBoolean(isActive); + out.writeLong(lastActiveTimestampMs); + out.writeLong(lastSampleTimestampMs); + if (initProgress != null) { + out.writeBoolean(true); + initProgress.writeTo(out); + } else { + out.writeBoolean(false); + } + if (modelProfile != null) { + out.writeBoolean(true); + modelProfile.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeEnum(state); + } + + @Override + public String toString() { + ToStringBuilder builder = new ToStringBuilder(this); + if (isActive != null) { + builder.append(IS_ACTIVE, isActive); + } + if (lastActiveTimestampMs > 0) { + builder.append(LAST_ACTIVE_TIMESTAMP, lastActiveTimestampMs); + } + if (lastSampleTimestampMs > 0) { + builder.append(LAST_SAMPLE_TIMESTAMP, lastSampleTimestampMs); + } + if (initProgress != null) { + builder.append(ADCommonName.INIT_PROGRESS, initProgress); + } + if (modelProfile != null) { + builder.append(ADCommonName.MODELS, modelProfile); + } + if (state != null && state != EntityState.UNKNOWN) { + builder.append(ADCommonName.STATE, state); + } + return builder.toString(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + if (obj instanceof EntityProfile) { + EntityProfile other = (EntityProfile) obj; + EqualsBuilder equalsBuilder = new EqualsBuilder(); + equalsBuilder.append(isActive, other.isActive); + equalsBuilder.append(lastActiveTimestampMs, other.lastActiveTimestampMs); + equalsBuilder.append(lastSampleTimestampMs, other.lastSampleTimestampMs); + equalsBuilder.append(initProgress, other.initProgress); + equalsBuilder.append(modelProfile, other.modelProfile); + equalsBuilder.append(state, other.state); + + return equalsBuilder.isEquals(); + } + return false; + } + + @Override + public int hashCode() { + return new HashCodeBuilder() + .append(isActive) + 
.append(lastActiveTimestampMs) + .append(lastSampleTimestampMs) + .append(initProgress) + .append(modelProfile) + .append(state) + .toHashCode(); + } + + @Override + public void merge(Mergeable other) { + if (this == other || other == null || getClass() != other.getClass()) { + return; + } + EntityProfile otherProfile = (EntityProfile) other; + + if (otherProfile.getInitProgress() != null) { + this.initProgress = otherProfile.getInitProgress(); + } + if (otherProfile.isActive != null) { + this.isActive = otherProfile.isActive; + } + if (otherProfile.lastActiveTimestampMs > 0) { + this.lastActiveTimestampMs = otherProfile.lastActiveTimestampMs; + } + if (otherProfile.lastSampleTimestampMs > 0) { + this.lastSampleTimestampMs = otherProfile.lastSampleTimestampMs; + } + if (otherProfile.modelProfile != null) { + this.modelProfile = otherProfile.modelProfile; + } + if (otherProfile.getState() != null && otherProfile.getState() != EntityState.UNKNOWN) { + this.state = otherProfile.getState(); + } + } +} diff --git a/src/main/java/org/opensearch/ad/model/EntityProfileName.java-e b/src/main/java/org/opensearch/ad/model/EntityProfileName.java-e new file mode 100644 index 000000000..84fd92987 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/EntityProfileName.java-e @@ -0,0 +1,61 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.util.Collection; +import java.util.Set; + +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.timeseries.Name; + +public enum EntityProfileName implements Name { + INIT_PROGRESS(ADCommonName.INIT_PROGRESS), + ENTITY_INFO(ADCommonName.ENTITY_INFO), + STATE(ADCommonName.STATE), + MODELS(ADCommonName.MODELS); + + private String name; + + EntityProfileName(String name) { + this.name = name; + } + + /** + * Get profile name + * + * @return name + */ + @Override + public String getName() { + return name; + } + + public static EntityProfileName getName(String name) { + switch (name) { + case ADCommonName.INIT_PROGRESS: + return INIT_PROGRESS; + case ADCommonName.ENTITY_INFO: + return ENTITY_INFO; + case ADCommonName.STATE: + return STATE; + case ADCommonName.MODELS: + return MODELS; + default: + throw new IllegalArgumentException(ADCommonMessages.UNSUPPORTED_PROFILE_TYPE); + } + } + + public static Set getNames(Collection names) { + return Name.getNameFromCollection(names, EntityProfileName::getName); + } +} diff --git a/src/main/java/org/opensearch/ad/model/EntityState.java-e b/src/main/java/org/opensearch/ad/model/EntityState.java-e new file mode 100644 index 000000000..1e0d05d8e --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/EntityState.java-e @@ -0,0 +1,18 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
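A small sketch of the EntityProfile builder defaults described above: unset timestamps stay at -1 and are omitted from toXContent/toString, and an UNKNOWN state is likewise omitted. The class name and values are illustrative.

import org.opensearch.ad.model.EntityProfile;
import org.opensearch.ad.model.EntityState;
import org.opensearch.ad.model.InitProgressProfile;

public class EntityProfileSketch {
    public static void main(String[] args) {
        // Only fields that were actually set are rendered; lastSampleTimestampMs is left
        // at its -1 default here and will be skipped by toXContent.
        EntityProfile profile = new EntityProfile.Builder()
            .isActive(true)
            .lastActiveTimestampMs(1_600_000_000_000L)            // illustrative epoch millis
            .state(EntityState.RUNNING)
            .initProgress(new InitProgressProfile("25%", 30, 3))  // illustrative progress
            .build();

        assert profile.getActive().orElse(false);
        assert profile.getState() == EntityState.RUNNING;
    }
}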
+ */ + +package org.opensearch.ad.model; + +public enum EntityState { + UNKNOWN, + INIT, + RUNNING +} diff --git a/src/main/java/org/opensearch/ad/model/ExpectedValueList.java b/src/main/java/org/opensearch/ad/model/ExpectedValueList.java index bad7e956b..14abc4cc6 100644 --- a/src/main/java/org/opensearch/ad/model/ExpectedValueList.java +++ b/src/main/java/org/opensearch/ad/model/ExpectedValueList.java @@ -11,7 +11,7 @@ package org.opensearch.ad.model; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.util.ArrayList; @@ -19,9 +19,9 @@ import org.apache.commons.lang.builder.ToStringBuilder; import org.apache.commons.lang3.StringUtils; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; diff --git a/src/main/java/org/opensearch/ad/model/ExpectedValueList.java-e b/src/main/java/org/opensearch/ad/model/ExpectedValueList.java-e new file mode 100644 index 000000000..14abc4cc6 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/ExpectedValueList.java-e @@ -0,0 +1,121 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.lang.builder.ToStringBuilder; +import org.apache.commons.lang3.StringUtils; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.DataByFeatureId; + +import com.google.common.base.Objects; + +public class ExpectedValueList implements ToXContentObject, Writeable { + public static final String LIKELIHOOD_FIELD = "likelihood"; + private Double likelihood; + private List valueList; + + public ExpectedValueList(Double likelihood, List valueList) { + this.likelihood = likelihood; + this.valueList = valueList; + } + + public ExpectedValueList(StreamInput input) throws IOException { + this.likelihood = input.readOptionalDouble(); + this.valueList = input.readList(DataByFeatureId::new); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + if (likelihood != null) { + xContentBuilder.field(LIKELIHOOD_FIELD, likelihood); + } + if (valueList != null) { + xContentBuilder.field(CommonName.VALUE_LIST_FIELD, valueList.toArray()); + } + return xContentBuilder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalDouble(likelihood); + out.writeList(valueList); + } + + public static ExpectedValueList parse(XContentParser parser) throws IOException { + Double likelihood = null; + List valueList = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case LIKELIHOOD_FIELD: + likelihood = parser.doubleValue(); + break; + case CommonName.VALUE_LIST_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + valueList.add(DataByFeatureId.parse(parser)); + } + break; + default: + // the unknown field and it's children should be ignored + parser.skipChildren(); + break; + } + } + + return new ExpectedValueList(likelihood, valueList); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + ExpectedValueList that = (ExpectedValueList) o; + return Double.compare(likelihood, that.likelihood) == 0 && Objects.equal(valueList, that.valueList); + } + + @Override + public int hashCode() { + return Objects.hashCode(likelihood, valueList); + } + + @Override + public String toString() { + return new ToStringBuilder(this).append("likelihood", likelihood).append("valueList", StringUtils.join(valueList, "|")).toString(); + } + + public Double getLikelihood() { + return likelihood; + } + + public List getValueList() { + return valueList; + } +} diff --git a/src/main/java/org/opensearch/ad/model/InitProgressProfile.java 
b/src/main/java/org/opensearch/ad/model/InitProgressProfile.java index c01259800..4147f8ef4 100644 --- a/src/main/java/org/opensearch/ad/model/InitProgressProfile.java +++ b/src/main/java/org/opensearch/ad/model/InitProgressProfile.java @@ -16,9 +16,9 @@ import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.commons.lang.builder.ToStringBuilder; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/model/InitProgressProfile.java-e b/src/main/java/org/opensearch/ad/model/InitProgressProfile.java-e new file mode 100644 index 000000000..4147f8ef4 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/InitProgressProfile.java-e @@ -0,0 +1,127 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.io.IOException; + +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +/** + * Profile output for detector initialization progress. When the new detector is created, it is possible that + * there hasn’t been enough continuous data in the index. We need to use live data to initialize. + * During initialization, we need to tell users progress (using a percentage), how many more + * shingles to go, and approximately how many minutes before the detector becomes operational + * if they keep their data stream continuous. 
+ * + */ +public class InitProgressProfile implements Writeable, ToXContent { + // field name in toXContent + public static final String PERCENTAGE = "percentage"; + public static final String ESTIMATED_MINUTES_LEFT = "estimated_minutes_left"; + public static final String NEEDED_SHINGLES = "needed_shingles"; + + private final String percentage; + private final long estimatedMinutesLeft; + private final int neededShingles; + + public InitProgressProfile(String percentage, long estimatedMinutesLeft, int neededDataPoints) { + super(); + this.percentage = percentage; + this.estimatedMinutesLeft = estimatedMinutesLeft; + this.neededShingles = neededDataPoints; + } + + public InitProgressProfile(StreamInput in) throws IOException { + percentage = in.readString(); + estimatedMinutesLeft = in.readVLong(); + neededShingles = in.readVInt(); + } + + public String getPercentage() { + return percentage; + } + + public long getEstimatedMinutesLeft() { + return estimatedMinutesLeft; + } + + public int getNeededDataPoints() { + return neededShingles; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(PERCENTAGE, percentage); + if (estimatedMinutesLeft > 0) { + builder.field(ESTIMATED_MINUTES_LEFT, estimatedMinutesLeft); + } + if (neededShingles > 0) { + builder.field(NEEDED_SHINGLES, neededShingles); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(percentage); + out.writeVLong(estimatedMinutesLeft); + out.writeVInt(neededShingles); + } + + @Override + public String toString() { + ToStringBuilder builder = new ToStringBuilder(this); + builder.append(PERCENTAGE, percentage); + if (estimatedMinutesLeft > 0) { + builder.append(ESTIMATED_MINUTES_LEFT, estimatedMinutesLeft); + } + if (neededShingles > 0) { + builder.append(NEEDED_SHINGLES, neededShingles); + } + return builder.toString(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + if (obj instanceof InitProgressProfile) { + InitProgressProfile other = (InitProgressProfile) obj; + + EqualsBuilder equalsBuilder = new EqualsBuilder(); + equalsBuilder.append(percentage, other.percentage); + equalsBuilder.append(estimatedMinutesLeft, other.estimatedMinutesLeft); + equalsBuilder.append(neededShingles, other.neededShingles); + + return equalsBuilder.isEquals(); + } + return false; + } + + @Override + public int hashCode() { + return new HashCodeBuilder().append(percentage).append(estimatedMinutesLeft).append(neededShingles).toHashCode(); + } +} diff --git a/src/main/java/org/opensearch/ad/model/Mergeable.java-e b/src/main/java/org/opensearch/ad/model/Mergeable.java-e new file mode 100644 index 000000000..980dad1a4 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/Mergeable.java-e @@ -0,0 +1,16 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
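An illustrative construction of the initialization-progress profile described in the Javadoc above. The numbers are made up but internally consistent (8 of 32 shingles done, 10-minute interval), and the class name is hypothetical.

import org.opensearch.ad.model.InitProgressProfile;

public class InitProgressSketch {
    public static void main(String[] args) {
        // 24 shingles still needed on a 10-minute interval: roughly 240 minutes remain
        // before the detector becomes operational, and 8 of 32 shingles (25%) are done.
        InitProgressProfile progress = new InitProgressProfile("25%", 24 * 10, 24);

        assert "25%".equals(progress.getPercentage());
        assert progress.getEstimatedMinutesLeft() == 240;
        assert progress.getNeededDataPoints() == 24;
    }
}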
+ */ + +package org.opensearch.ad.model; + +public interface Mergeable { + void merge(Mergeable other); +} diff --git a/src/main/java/org/opensearch/ad/model/MergeableList.java-e b/src/main/java/org/opensearch/ad/model/MergeableList.java-e new file mode 100644 index 000000000..4bb0d7842 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/MergeableList.java-e @@ -0,0 +1,38 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.util.List; + +public class MergeableList implements Mergeable { + + private final List elements; + + public List getElements() { + return elements; + } + + public MergeableList(List elements) { + this.elements = elements; + } + + @Override + public void merge(Mergeable other) { + if (this == other || other == null || getClass() != other.getClass()) { + return; + } + MergeableList otherList = (MergeableList) other; + if (otherList.getElements() != null) { + this.elements.addAll(otherList.getElements()); + } + } +} diff --git a/src/main/java/org/opensearch/ad/model/ModelProfile.java b/src/main/java/org/opensearch/ad/model/ModelProfile.java index b98860f33..1d6d0ce85 100644 --- a/src/main/java/org/opensearch/ad/model/ModelProfile.java +++ b/src/main/java/org/opensearch/ad/model/ModelProfile.java @@ -16,9 +16,9 @@ import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.commons.lang.builder.ToStringBuilder; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.constant.CommonName; diff --git a/src/main/java/org/opensearch/ad/model/ModelProfile.java-e b/src/main/java/org/opensearch/ad/model/ModelProfile.java-e new file mode 100644 index 000000000..1d6d0ce85 --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/ModelProfile.java-e @@ -0,0 +1,127 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
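A sketch of the list-concatenation merge defined by MergeableList above, using plain strings for illustration. Note the receiver's backing list must be mutable, since merge() calls addAll on it.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.opensearch.ad.model.MergeableList;

public class MergeableListSketch {
    public static void main(String[] args) {
        // Each partial response contributes a list; merge() concatenates the other
        // list's elements into the receiver's backing list.
        MergeableList first = new MergeableList(new ArrayList<>(Arrays.asList("a", "b")));
        MergeableList second = new MergeableList(Arrays.asList("c"));

        first.merge(second);

        List merged = first.getElements();
        assert merged.size() == 3;   // ["a", "b", "c"]
    }
}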
+ */ + +package org.opensearch.ad.model; + +import java.io.IOException; + +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; + +/** + * Used to show model information in profile API + * + */ +public class ModelProfile implements Writeable, ToXContentObject { + private final String modelId; + // added since Opensearch 1.1 + private final Entity entity; + private final long modelSizeInBytes; + + public ModelProfile(String modelId, Entity entity, long modelSizeInBytes) { + super(); + this.modelId = modelId; + this.entity = entity; + this.modelSizeInBytes = modelSizeInBytes; + } + + public ModelProfile(StreamInput in) throws IOException { + this.modelId = in.readString(); + if (in.readBoolean()) { + this.entity = new Entity(in); + } else { + this.entity = null; + } + + this.modelSizeInBytes = in.readLong(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + if (entity != null) { + out.writeBoolean(true); + entity.writeTo(out); + } else { + out.writeBoolean(false); + } + + out.writeLong(modelSizeInBytes); + } + + public String getModelId() { + return modelId; + } + + public Entity getEntity() { + return entity; + } + + public long getModelSizeInBytes() { + return modelSizeInBytes; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(CommonName.MODEL_ID_FIELD, modelId); + if (entity != null) { + builder.field(CommonName.ENTITY_KEY, entity); + } + if (modelSizeInBytes > 0) { + builder.field(CommonName.MODEL_SIZE_IN_BYTES, modelSizeInBytes); + } + return builder; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + if (obj instanceof ModelProfile) { + ModelProfile other = (ModelProfile) obj; + EqualsBuilder equalsBuilder = new EqualsBuilder(); + equalsBuilder.append(modelId, other.modelId); + + return equalsBuilder.isEquals(); + } + return false; + } + + @Override + public int hashCode() { + return new HashCodeBuilder().append(modelId).toHashCode(); + } + + @Override + public String toString() { + ToStringBuilder builder = new ToStringBuilder(this); + builder.append(CommonName.MODEL_ID_FIELD, modelId); + if (modelSizeInBytes > 0) { + builder.append(CommonName.MODEL_SIZE_IN_BYTES, modelSizeInBytes); + } + if (entity != null) { + builder.append(CommonName.ENTITY_KEY, entity); + } + return builder.toString(); + } +} diff --git a/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java b/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java index a78c90722..1e45bcc7a 100644 --- a/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java +++ b/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java @@ -17,9 +17,9 @@ import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.commons.lang.builder.ToStringBuilder; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.common.io.stream.StreamInput; 
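A sketch of a single-stream (no entity) ModelProfile, exercising the null-entity handling above: writeTo records a false flag and toXContent skips the entity field. The model id and size are illustrative, not the plugin's real naming scheme.

import org.opensearch.ad.model.ModelProfile;

public class ModelProfileSketch {
    public static void main(String[] args) {
        // Single-stream detectors have no entity attached to the model.
        ModelProfile profile = new ModelProfile("detector-1_model_rcf_0", null, 4_096L);  // illustrative id/size

        assert profile.getEntity() == null;
        assert profile.getModelSizeInBytes() == 4_096L;
    }
}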
-import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java-e b/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java-e new file mode 100644 index 000000000..1e45bcc7a --- /dev/null +++ b/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java-e @@ -0,0 +1,105 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.io.IOException; + +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +public class ModelProfileOnNode implements Writeable, ToXContent { + // field name in toXContent + public static final String NODE_ID = "node_id"; + + private final String nodeId; + private final ModelProfile modelProfile; + + public ModelProfileOnNode(String nodeId, ModelProfile modelProfile) { + this.nodeId = nodeId; + this.modelProfile = modelProfile; + } + + public ModelProfileOnNode(StreamInput in) throws IOException { + this.nodeId = in.readString(); + this.modelProfile = new ModelProfile(in); + } + + public String getModelId() { + return modelProfile.getModelId(); + } + + public long getModelSize() { + return modelProfile.getModelSizeInBytes(); + } + + public String getNodeId() { + return nodeId; + } + + public ModelProfile getModelProfile() { + return modelProfile; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + modelProfile.toXContent(builder, params); + builder.field(NODE_ID, nodeId); + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(nodeId); + modelProfile.writeTo(out); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + if (obj instanceof ModelProfileOnNode) { + ModelProfileOnNode other = (ModelProfileOnNode) obj; + EqualsBuilder equalsBuilder = new EqualsBuilder(); + equalsBuilder.append(modelProfile, other.modelProfile); + equalsBuilder.append(nodeId, other.nodeId); + + return equalsBuilder.isEquals(); + } + return false; + } + + @Override + public int hashCode() { + return new HashCodeBuilder().append(modelProfile).append(nodeId).toHashCode(); + } + + @Override + public String toString() { + ToStringBuilder builder = new ToStringBuilder(this); + builder.append(ADCommonName.MODEL, modelProfile); + builder.append(NODE_ID, nodeId); + return builder.toString(); + 
} +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java b/src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java index bbe3ffc00..50b051f6d 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java +++ b/src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java @@ -20,13 +20,13 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionListener; import org.opensearch.action.support.ThreadedActionListener; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; /** * @@ -111,7 +111,7 @@ protected void execute(Runnable afterProcessCallback, Runnable emptyQueueCallbac ThreadedActionListener listener = new ThreadedActionListener<>( LOG, threadPool, - AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, getResponseListener(toProcess, batchRequest), false ); diff --git a/src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java-e b/src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java-e new file mode 100644 index 000000000..50b051f6d --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java-e @@ -0,0 +1,135 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
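A hedged aside on the ModelProfile and ModelProfileOnNode classes above: a minimal round-trip sketch of the Writeable contract they implement. Package locations (BytesStreamOutput in org.opensearch.common.io.stream, StreamInput in org.opensearch.core.common.io.stream) are assumptions based on OpenSearch 2.x; the node id and size are illustrative.

```java
import org.opensearch.ad.model.ModelProfile;
import org.opensearch.ad.model.ModelProfileOnNode;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;

public class ModelProfileWireSketch {
    public static void main(String[] args) throws Exception {
        // entity may be null for single-stream detectors; writeTo guards it with a boolean flag
        ModelProfile profile = new ModelProfile("model-1", null, 4096L);
        ModelProfileOnNode onNode = new ModelProfileOnNode("node-A", profile);

        try (BytesStreamOutput out = new BytesStreamOutput()) {
            onNode.writeTo(out);
            try (StreamInput in = out.bytes().streamInput()) {
                ModelProfileOnNode copy = new ModelProfileOnNode(in);
                // equals compares the nested ModelProfile (by model id) and the node id
                System.out.println(copy.equals(onNode)); // true
            }
        }
    }
}
```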
+ */ + +package org.opensearch.ad.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ThreadedActionListener; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +/** + * + * @param Individual request type that is a subtype of ADRequest + * @param Batch request type like BulkRequest + * @param Response type like BulkResponse + */ +public abstract class BatchWorker extends + ConcurrentWorker { + private static final Logger LOG = LogManager.getLogger(BatchWorker.class); + protected int batchSize; + + public BatchWorker( + String queueName, + long heapSize, + int singleRequestSize, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Setting concurrencySetting, + Duration executionTtl, + Setting batchSizeSetting, + Duration stateTtl, + NodeStateManager nodeStateManager + ) { + super( + queueName, + heapSize, + singleRequestSize, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + concurrencySetting, + executionTtl, + stateTtl, + nodeStateManager + ); + this.batchSize = batchSizeSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(batchSizeSetting, it -> batchSize = it); + } + + /** + * Used by subclasses to creates customized logic to send batch requests. + * After everything finishes, the method should call listener. + * @param request Batch request to execute + * @param listener customized listener + */ + protected abstract void executeBatchRequest(BatchRequestType request, ActionListener listener); + + /** + * We convert from queued requests understood by AD to batchRequest understood by OpenSearch. + * @param toProcess Queued requests + * @return batch requests + */ + protected abstract BatchRequestType toBatchRequest(List toProcess); + + @Override + protected void execute(Runnable afterProcessCallback, Runnable emptyQueueCallback) { + + List toProcess = getRequests(batchSize); + + // it is possible other concurrent threads have drained the queue + if (false == toProcess.isEmpty()) { + BatchRequestType batchRequest = toBatchRequest(toProcess); + + ThreadedActionListener listener = new ThreadedActionListener<>( + LOG, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + getResponseListener(toProcess, batchRequest), + false + ); + + final ActionListener listenerWithRelease = ActionListener.runAfter(listener, afterProcessCallback); + executeBatchRequest(batchRequest, listenerWithRelease); + } else { + emptyQueueCallback.run(); + } + } + + /** + * Used by subclasses to creates customized logic to handle batch responses + * or errors. 
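A hedged aside on the ActionListener.runAfter call in BatchWorker.execute above: the wrapped Runnable runs after either onResponse or onFailure, which is what guarantees the queue slot is released no matter how the batch request ends. The listener below is a standalone sketch with illustrative values.

```java
import org.opensearch.action.ActionListener;

public class RunAfterSketch {
    public static void main(String[] args) {
        ActionListener<String> base = ActionListener.wrap(
            response -> System.out.println("batch response: " + response),
            failure -> System.out.println("batch failure: " + failure.getMessage())
        );
        // afterProcessCallback in BatchWorker plays the role of this Runnable
        ActionListener<String> withRelease = ActionListener.runAfter(base, () -> System.out.println("slot released"));

        withRelease.onResponse("ok"); // prints the response, then "slot released"
        // an onFailure call would likewise be followed by "slot released"
    }
}
```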
+ * @param toProcess Queued request used to retrieve information of retrying requests + * @param batchRequest Batch request corresponding to toProcess. We convert + * from toProcess understood by AD to batchRequest understood by ES. + * @return Listener to BatchResponse + */ + protected abstract ActionListener getResponseListener(List toProcess, BatchRequestType batchRequest); +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapter.java-e b/src/main/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapter.java-e new file mode 100644 index 000000000..91382a4b5 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapter.java-e @@ -0,0 +1,110 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Map; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.ml.CheckpointDao; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.util.DateUtils; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.Strings; + +public class CheckPointMaintainRequestAdapter { + private static final Logger LOG = LogManager.getLogger(CheckPointMaintainRequestAdapter.class); + private CacheProvider cache; + private CheckpointDao checkpointDao; + private String indexName; + private Duration checkpointInterval; + private Clock clock; + + public CheckPointMaintainRequestAdapter( + CacheProvider cache, + CheckpointDao checkpointDao, + String indexName, + Setting checkpointIntervalSetting, + Clock clock, + ClusterService clusterService, + Settings settings + ) { + this.cache = cache; + this.checkpointDao = checkpointDao; + this.indexName = indexName; + + this.checkpointInterval = DateUtils.toDuration(checkpointIntervalSetting.get(settings)); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(checkpointIntervalSetting, it -> this.checkpointInterval = DateUtils.toDuration(it)); + + this.clock = clock; + } + + public Optional convert(CheckpointMaintainRequest request) { + String detectorId = request.getId(); + String modelId = request.getEntityModelId(); + + Optional> stateToMaintain = cache.get().getForMaintainance(detectorId, modelId); + if (!stateToMaintain.isEmpty()) { + ModelState state = stateToMaintain.get(); + Instant instant = state.getLastCheckpointTime(); + if (!checkpointDao.shouldSave(instant, false, checkpointInterval, clock)) { + return Optional.empty(); + } + + try { + Map source = checkpointDao.toIndexSource(state); + + // the model state is bloated or empty (empty samples and models), skip + if (source == null || source.isEmpty() || Strings.isEmpty(modelId)) { + return Optional.empty(); + } + + return Optional + .of( + new CheckpointWriteRequest( + 
request.getExpirationEpochMs(), + detectorId, + request.getPriority(), + // If the document does not already exist, the contents of the upsert element + // are inserted as a new document. + // If the document exists, update fields in the map + new UpdateRequest(indexName, modelId).docAsUpsert(true).doc(source) + ) + ); + } catch (Exception e) { + // Example exception: + // ConcurrentModificationException when calling toIndexSource + // and updating rcf model at the same time. To prevent this, + // we need to have a deep copy of models or have a lock. Both + // options are costly. + // As we are gonna retry serializing either when the entity is + // evicted out of cache or during the next maintenance period, + // don't do anything when the exception happens. + LOG.error(new ParameterizedMessage("Exception while serializing models for [{}]", modelId), e); + } + } + return Optional.empty(); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainRequest.java-e b/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainRequest.java-e new file mode 100644 index 000000000..28fdfcc91 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainRequest.java-e @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +public class CheckpointMaintainRequest extends QueuedRequest { + private String entityModelId; + + public CheckpointMaintainRequest(long expirationEpochMs, String detectorId, RequestPriority priority, String entityModelId) { + super(expirationEpochMs, detectorId, priority); + this.entityModelId = entityModelId; + } + + public String getEntityModelId() { + return entityModelId; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorker.java-e b/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorker.java-e new file mode 100644 index 000000000..049b2d587 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorker.java-e @@ -0,0 +1,104 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
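A hedged aside on the UpdateRequest built by CheckPointMaintainRequestAdapter.convert above: docAsUpsert(true) lets one request cover both cases, inserting the source as a new document when the model id is absent and merging the fields into the existing checkpoint otherwise. The index name, id, and source below are illustrative.

```java
import java.util.Map;

import org.opensearch.action.update.UpdateRequest;

public class CheckpointUpsertSketch {
    public static UpdateRequest checkpointUpsert(String indexName, String modelId, Map<String, Object> source) {
        return new UpdateRequest(indexName, modelId)
            .docAsUpsert(true)   // insert source as a new document if the id does not exist
            .doc(source);        // otherwise apply source as a partial update
    }
}
```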
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS; + +import java.time.Clock; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; + +public class CheckpointMaintainWorker extends ScheduledWorker { + private static final Logger LOG = LogManager.getLogger(CheckpointMaintainWorker.class); + public static final String WORKER_NAME = "checkpoint-maintain"; + + private CheckPointMaintainRequestAdapter adapter; + + public CheckpointMaintainWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + CheckpointWriteWorker checkpointWriteQueue, + Duration stateTtl, + NodeStateManager nodeStateManager, + CheckPointMaintainRequestAdapter adapter + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + checkpointWriteQueue, + stateTtl, + nodeStateManager + ); + + this.batchSize = AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, it -> this.batchSize = it); + + this.expectedExecutionTimeInMilliSecsPerRequest = AnomalyDetectorSettings.AD_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS + .get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer( + AD_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS, + it -> this.expectedExecutionTimeInMilliSecsPerRequest = it + ); + this.adapter = adapter; + } + + @Override + protected List transformRequests(List requests) { + List allRequests = new ArrayList<>(); + for (CheckpointMaintainRequest request : requests) { + Optional converted = adapter.convert(request); + if (!converted.isEmpty()) { + allRequests.add(converted.get()); + } + } + return allRequests; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java-e b/src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java-e new file mode 100644 index 000000000..e06d3e08e --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java-e @@ -0,0 +1,436 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. 
See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Random; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.MultiGetItemResponse; +import org.opensearch.action.get.MultiGetRequest; +import org.opensearch.action.get.MultiGetResponse; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.CheckpointDao; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.util.ParseUtils; + +/** + * a queue for loading model checkpoint. The read is a multi-get query. Possible results are: + * a). If a checkpoint is not found, we forward that request to the cold start queue. + * b). When a request gets errors, the queue does not change its expiry time and puts + * that request to the end of the queue and automatically retries them before they expire. + * c) When a checkpoint is found, we load that point to memory and score the input + * data point and save the result if a complete model exists. Otherwise, we enqueue + * the sample. If we can host that model in memory (e.g., there is enough memory), + * we put the loaded model to cache. Otherwise (e.g., a cold entity), we write the + * updated checkpoint back to disk. 
+ * + */ +public class CheckpointReadWorker extends BatchWorker { + private static final Logger LOG = LogManager.getLogger(CheckpointReadWorker.class); + public static final String WORKER_NAME = "checkpoint-read"; + private final ModelManager modelManager; + private final CheckpointDao checkpointDao; + private final EntityColdStartWorker entityColdStartQueue; + private final ResultWriteWorker resultWriteQueue; + private final ADIndexManagement indexUtil; + private final CacheProvider cacheProvider; + private final CheckpointWriteWorker checkpointWriteQueue; + private final ADStats adStats; + + public CheckpointReadWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ModelManager modelManager, + CheckpointDao checkpointDao, + EntityColdStartWorker entityColdStartQueue, + ResultWriteWorker resultWriteQueue, + NodeStateManager stateManager, + ADIndexManagement indexUtil, + CacheProvider cacheProvider, + Duration stateTtl, + CheckpointWriteWorker checkpointWriteQueue, + ADStats adStats + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + AD_CHECKPOINT_READ_QUEUE_CONCURRENCY, + executionTtl, + AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + stateTtl, + stateManager + ); + + this.modelManager = modelManager; + this.checkpointDao = checkpointDao; + this.entityColdStartQueue = entityColdStartQueue; + this.resultWriteQueue = resultWriteQueue; + this.indexUtil = indexUtil; + this.cacheProvider = cacheProvider; + this.checkpointWriteQueue = checkpointWriteQueue; + this.adStats = adStats; + } + + @Override + protected void executeBatchRequest(MultiGetRequest request, ActionListener listener) { + checkpointDao.batchRead(request, listener); + } + + /** + * Convert the input list of EntityFeatureRequest to a multi-get request. + * RateLimitedRequestWorker.getRequests has already limited the number of + * requests in the input list. So toBatchRequest method can take the input + * and send the multi-get directly. 
+ * @return The converted multi-get request + */ + @Override + protected MultiGetRequest toBatchRequest(List toProcess) { + MultiGetRequest multiGetRequest = new MultiGetRequest(); + for (EntityRequest request : toProcess) { + Optional modelId = request.getModelId(); + if (false == modelId.isPresent()) { + continue; + } + multiGetRequest.add(new MultiGetRequest.Item(ADCommonName.CHECKPOINT_INDEX_NAME, modelId.get())); + } + return multiGetRequest; + } + + @Override + protected ActionListener getResponseListener(List toProcess, MultiGetRequest batchRequest) { + return ActionListener.wrap(response -> { + final MultiGetItemResponse[] itemResponses = response.getResponses(); + Map successfulRequests = new HashMap<>(); + + // lazy init since we don't expect retryable requests to happen often + Set retryableRequests = null; + Set notFoundModels = null; + boolean printedUnexpectedFailure = false; + // contain requests that we will set the detector's exception to + // EndRunException (stop now = false) + Map stopDetectorRequests = null; + for (MultiGetItemResponse itemResponse : itemResponses) { + String modelId = itemResponse.getId(); + if (itemResponse.isFailed()) { + final Exception failure = itemResponse.getFailure().getFailure(); + if (failure instanceof IndexNotFoundException) { + for (EntityRequest origRequest : toProcess) { + // If it is checkpoint index not found exception, I don't + // need to retry as checkpoint read is bound to fail. Just + // send everything to the cold start queue and return. + entityColdStartQueue.put(origRequest); + } + return; + } else if (ExceptionUtil.isRetryAble(failure)) { + if (retryableRequests == null) { + retryableRequests = new HashSet<>(); + } + retryableRequests.add(modelId); + } else if (ExceptionUtil.isOverloaded(failure)) { + LOG.error("too many get AD model checkpoint requests or shard not available"); + setCoolDownStart(); + } else { + // some unexpected bug occurred or cluster is unstable (e.g., ClusterBlockException) or index is red (e.g. + // NoShardAvailableActionException) while fetching a checkpoint. As this might happen for a large amount + // of entities, we don't want to flood logs with such exception trace. Only print it once. + if (!printedUnexpectedFailure) { + LOG.error("Unexpected failure", failure); + printedUnexpectedFailure = true; + } + if (stopDetectorRequests == null) { + stopDetectorRequests = new HashMap<>(); + } + stopDetectorRequests.put(modelId, failure); + } + } else if (!itemResponse.getResponse().isExists()) { + // lazy init as we don't expect retrying happens often + if (notFoundModels == null) { + notFoundModels = new HashSet<>(); + } + notFoundModels.add(modelId); + } else { + successfulRequests.put(modelId, itemResponse); + } + } + + // deal with not found model + if (notFoundModels != null) { + for (EntityRequest origRequest : toProcess) { + Optional modelId = origRequest.getModelId(); + if (modelId.isPresent() && notFoundModels.contains(modelId.get())) { + // submit to cold start queue + entityColdStartQueue.put(origRequest); + } + } + } + + // deal with failures that we will retry for a limited amount of times + // before stopping the detector + // We cannot just loop over stopDetectorRequests instead of toProcess + // because we need detector id from toProcess' elements. stopDetectorRequests only has model id. 
+ if (stopDetectorRequests != null) { + for (EntityRequest origRequest : toProcess) { + Optional modelId = origRequest.getModelId(); + if (modelId.isPresent() && stopDetectorRequests.containsKey(modelId.get())) { + String adID = origRequest.detectorId; + nodeStateManager + .setException( + adID, + new EndRunException(adID, CommonMessages.BUG_RESPONSE, stopDetectorRequests.get(modelId.get()), false) + ); + } + } + } + + if (successfulRequests.isEmpty() && (retryableRequests == null || retryableRequests.isEmpty())) { + // don't need to proceed further since no checkpoint is available + return; + } + + processCheckpointIteration(0, toProcess, successfulRequests, retryableRequests); + }, exception -> { + if (ExceptionUtil.isOverloaded(exception)) { + LOG.error("too many get AD model checkpoint requests or shard not available"); + setCoolDownStart(); + } else if (ExceptionUtil.isRetryAble(exception)) { + // retry all of them + putAll(toProcess); + } else { + LOG.error("Fail to restore models", exception); + } + }); + } + + private void processCheckpointIteration( + int i, + List toProcess, + Map successfulRequests, + Set retryableRequests + ) { + if (i >= toProcess.size()) { + return; + } + + // whether we will process next response in callbacks + // if false, finally will process next checkpoints + boolean processNextInCallBack = false; + try { + EntityFeatureRequest origRequest = toProcess.get(i); + + Optional modelIdOptional = origRequest.getModelId(); + if (false == modelIdOptional.isPresent()) { + return; + } + + String detectorId = origRequest.getId(); + Entity entity = origRequest.getEntity(); + + String modelId = modelIdOptional.get(); + + MultiGetItemResponse checkpointResponse = successfulRequests.get(modelId); + + if (checkpointResponse != null) { + // successful requests + Optional> checkpoint = checkpointDao + .processGetResponse(checkpointResponse.getResponse(), modelId); + + if (false == checkpoint.isPresent()) { + // checkpoint is too big + return; + } + + nodeStateManager + .getAnomalyDetector( + detectorId, + onGetDetector( + origRequest, + i, + detectorId, + toProcess, + successfulRequests, + retryableRequests, + checkpoint, + entity, + modelId + ) + ); + processNextInCallBack = true; + } else if (retryableRequests != null && retryableRequests.contains(modelId)) { + // failed requests + super.put(origRequest); + } + } finally { + if (false == processNextInCallBack) { + processCheckpointIteration(i + 1, toProcess, successfulRequests, retryableRequests); + } + } + } + + private ActionListener> onGetDetector( + EntityFeatureRequest origRequest, + int index, + String detectorId, + List toProcess, + Map successfulRequests, + Set retryableRequests, + Optional> checkpoint, + Entity entity, + String modelId + ) { + return ActionListener.wrap(detectorOptional -> { + if (false == detectorOptional.isPresent()) { + LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); + return; + } + + AnomalyDetector detector = detectorOptional.get(); + + ModelState modelState = modelManager + .processEntityCheckpoint(checkpoint, entity, modelId, detectorId, detector.getShingleSize()); + + ThresholdingResult result = null; + try { + result = modelManager + .getAnomalyResultForEntity(origRequest.getCurrentFeature(), modelState, modelId, entity, detector.getShingleSize()); + } catch (IllegalArgumentException e) { + // fail to score likely due to model corruption. 
Re-cold start to recover. + LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", origRequest.getModelId()), e); + adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).increment(); + if (origRequest.getModelId().isPresent()) { + String entityModelId = origRequest.getModelId().get(); + checkpointDao + .deleteModelCheckpoint( + entityModelId, + ActionListener + .wrap( + r -> LOG.debug(new ParameterizedMessage("Succeeded in deleting checkpoint [{}].", entityModelId)), + ex -> LOG.error(new ParameterizedMessage("Failed to delete checkpoint [{}].", entityModelId), ex) + ) + ); + } + + entityColdStartQueue.put(origRequest); + processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); + return; + } + + if (result != null && result.getRcfScore() > 0) { + RequestPriority requestPriority = result.getGrade() > 0 ? RequestPriority.HIGH : RequestPriority.MEDIUM; + + List resultsToSave = result + .toIndexableResults( + detector, + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + detector.getIntervalInMilliseconds()), + Instant.now(), + Instant.now(), + ParseUtils.getFeatureData(origRequest.getCurrentFeature(), detector), + Optional.ofNullable(entity), + indexUtil.getSchemaVersion(ADIndex.RESULT), + modelId, + null, + null + ); + + for (AnomalyResult r : resultsToSave) { + resultWriteQueue + .put( + new ResultWriteRequest( + origRequest.getExpirationEpochMs(), + detectorId, + requestPriority, + r, + detector.getCustomResultIndex() + ) + ); + } + } + + // try to load to cache + boolean loaded = cacheProvider.get().hostIfPossible(detector, modelState); + + if (false == loaded) { + // not in memory. Maybe cold entities or some other entities + // have filled the slot while waiting for loading checkpoints. + checkpointWriteQueue.write(modelState, true, RequestPriority.LOW); + } + + processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); + }, exception -> { + LOG.error(new ParameterizedMessage("fail to get checkpoint [{}]", modelId, exception)); + nodeStateManager.setException(detectorId, exception); + processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); + }); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteRequest.java-e b/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteRequest.java-e new file mode 100644 index 000000000..9c41e55be --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteRequest.java-e @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import org.opensearch.action.update.UpdateRequest; + +public class CheckpointWriteRequest extends QueuedRequest { + private final UpdateRequest updateRequest; + + public CheckpointWriteRequest(long expirationEpochMs, String detectorId, RequestPriority priority, UpdateRequest updateRequest) { + super(expirationEpochMs, detectorId, priority); + this.updateRequest = updateRequest; + } + + public UpdateRequest getUpdateRequest() { + return updateRequest; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteWorker.java-e b/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteWorker.java-e new file mode 100644 index 000000000..dd32e21c4 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteWorker.java-e @@ -0,0 +1,274 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.ml.CheckpointDao; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; +import org.opensearch.threadpool.ThreadPool; + +public class CheckpointWriteWorker extends BatchWorker { + private static final Logger LOG = LogManager.getLogger(CheckpointWriteWorker.class); + public static final String WORKER_NAME = "checkpoint-write"; + + private final CheckpointDao checkpoint; + private final String indexName; + private final Duration checkpointInterval; + + public CheckpointWriteWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + CheckpointDao checkpoint, + String indexName, + Duration checkpointInterval, + NodeStateManager stateManager, + Duration stateTtl + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + 
maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + executionTtl, + AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + stateTtl, + stateManager + ); + this.checkpoint = checkpoint; + this.indexName = indexName; + this.checkpointInterval = checkpointInterval; + } + + @Override + protected void executeBatchRequest(BulkRequest request, ActionListener listener) { + checkpoint.batchWrite(request, listener); + } + + @Override + protected BulkRequest toBatchRequest(List toProcess) { + final BulkRequest bulkRequest = new BulkRequest(); + for (CheckpointWriteRequest request : toProcess) { + bulkRequest.add(request.getUpdateRequest()); + } + return bulkRequest; + } + + @Override + protected ActionListener getResponseListener(List toProcess, BulkRequest batchRequest) { + return ActionListener.wrap(response -> { + for (BulkItemResponse r : response.getItems()) { + if (r.getFailureMessage() != null) { + // maybe indicating a bug + // don't retry failed requests since checkpoints are too large (250KB+) + // Later maintenance window or cold start or cache remove will retry saving + LOG.error(r.getFailureMessage()); + } + } + }, exception -> { + if (ExceptionUtil.isOverloaded(exception)) { + LOG.error("too many get AD model checkpoint requests or shard not avialble"); + setCoolDownStart(); + } + + for (CheckpointWriteRequest request : toProcess) { + nodeStateManager.setException(request.getId(), exception); + } + + // don't retry failed requests since checkpoints are too large (250KB+) + // Later maintenance window or cold start or cache remove will retry saving + LOG.error("Fail to save models", exception); + }); + } + + /** + * Prepare bulking the input model state to the checkpoint index. + * We don't save checkpoints within checkpointInterval again, except this + * is a high priority request (e.g., from cold start). + * This method will update the input state's last checkpoint time if the + * checkpoint is staged (ready to be written in the next batch). 
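The interval throttling described above is decided by CheckpointDao.shouldSave; the snippet below is a simplified, hedged restatement of that rule, not the actual implementation, which may handle more cases.

```java
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;

final class CheckpointThrottleSketch {
    // Save when forced, when the model has never been checkpointed, or when the
    // last checkpoint is older than checkpointInterval.
    static boolean shouldSave(Instant lastCheckpointTime, boolean forceWrite, Duration checkpointInterval, Clock clock) {
        if (forceWrite) {
            return true;
        }
        if (lastCheckpointTime == null || Instant.MIN.equals(lastCheckpointTime)) {
            return true;
        }
        return lastCheckpointTime.plus(checkpointInterval).isBefore(clock.instant());
    }
}
```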
+ * @param modelState Model state + * @param forceWrite whether we should write no matter what + * @param priority how urgent the write is + */ + public void write(ModelState modelState, boolean forceWrite, RequestPriority priority) { + Instant instant = modelState.getLastCheckpointTime(); + if (!checkpoint.shouldSave(instant, forceWrite, checkpointInterval, clock)) { + return; + } + + if (modelState.getModel() != null) { + String detectorId = modelState.getId(); + String modelId = modelState.getModelId(); + if (modelId == null || detectorId == null) { + return; + } + + nodeStateManager.getAnomalyDetector(detectorId, onGetDetector(detectorId, modelId, modelState, priority)); + } + } + + private ActionListener> onGetDetector( + String detectorId, + String modelId, + ModelState modelState, + RequestPriority priority + ) { + return ActionListener.wrap(detectorOptional -> { + if (false == detectorOptional.isPresent()) { + LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + return; + } + + AnomalyDetector detector = detectorOptional.get(); + try { + Map source = checkpoint.toIndexSource(modelState); + + // the model state is bloated or we have bugs, skip + if (source == null || source.isEmpty()) { + return; + } + + modelState.setLastCheckpointTime(clock.instant()); + CheckpointWriteRequest request = new CheckpointWriteRequest( + System.currentTimeMillis() + detector.getIntervalInMilliseconds(), + detectorId, + priority, + // If the document does not already exist, the contents of the upsert element + // are inserted as a new document. + // If the document exists, update fields in the map + new UpdateRequest(indexName, modelId).docAsUpsert(true).doc(source) + ); + + put(request); + } catch (Exception e) { + // Example exception: + // ConcurrentModificationException when calling toCheckpoint + // and updating rcf model at the same time. To prevent this, + // we need to have a deep copy of models or have a lock. Both + // options are costly. + // As we are gonna retry serializing either when the entity is + // evicted out of cache or during the next maintenance period, + // don't do anything when the exception happens. 
+ LOG.error(new ParameterizedMessage("Exception while serializing models for [{}]", modelId), e); + } + + }, exception -> { LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception); }); + } + + public void writeAll(List> modelStates, String detectorId, boolean forceWrite, RequestPriority priority) { + ActionListener> onGetForAll = ActionListener.wrap(detectorOptional -> { + if (false == detectorOptional.isPresent()) { + LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + return; + } + + AnomalyDetector detector = detectorOptional.get(); + try { + List allRequests = new ArrayList<>(); + for (ModelState state : modelStates) { + Instant instant = state.getLastCheckpointTime(); + if (!checkpoint.shouldSave(instant, forceWrite, checkpointInterval, clock)) { + continue; + } + + Map source = checkpoint.toIndexSource(state); + String modelId = state.getModelId(); + + // the model state is bloated or empty (empty samples and models), skip + if (source == null || source.isEmpty() || Strings.isEmpty(modelId)) { + continue; + } + + state.setLastCheckpointTime(clock.instant()); + allRequests + .add( + new CheckpointWriteRequest( + System.currentTimeMillis() + detector.getIntervalInMilliseconds(), + detectorId, + priority, + // If the document does not already exist, the contents of the upsert element + // are inserted as a new document. + // If the document exists, update fields in the map + new UpdateRequest(indexName, modelId).docAsUpsert(true).doc(source) + ) + ); + } + + putAll(allRequests); + } catch (Exception e) { + // Example exception: + // ConcurrentModificationException when calling toCheckpoint + // and updating rcf model at the same time. To prevent this, + // we need to have a deep copy of models or have a lock. Both + // options are costly. + // As we are gonna retry serializing either when the entity is + // evicted out of cache or during the next maintenance period, + // don't do anything when the exception happens. + LOG.info(new ParameterizedMessage("Exception while serializing models for [{}]", detectorId), e); + } + + }, exception -> { LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception); }); + + nodeStateManager.getAnomalyDetector(detectorId, onGetForAll); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ColdEntityWorker.java-e b/src/main/java/org/opensearch/ad/ratelimit/ColdEntityWorker.java-e new file mode 100644 index 000000000..fb834e089 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ColdEntityWorker.java-e @@ -0,0 +1,105 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; + +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; + +/** + * A queue slowly releasing low-priority requests to CheckpointReadQueue + * + * ColdEntityQueue is a queue to absorb cold entities. Like hot entities, we load a cold + * entity's model checkpoint from disk, train models if the checkpoint is not found, + * query for missed features to complete a shingle, use the models to check whether + * the incoming feature is normal, update models, and save the detection results to disks.  + * Implementation-wise, we reuse the queues we have developed for hot entities. + * The differences are: we process hot entities as long as resources (e.g., AD + * thread pool has availability) are available, while we release cold entity requests + * to other queues at a slow controlled pace. Also, cold entity requests' priority is low. + * So only when there are no hot entity requests to process are we going to process cold + * entity requests.  + * + */ +public class ColdEntityWorker extends ScheduledWorker { + public static final String WORKER_NAME = "cold-entity"; + + public ColdEntityWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + CheckpointReadWorker checkpointReadQueue, + Duration stateTtl, + NodeStateManager nodeStateManager + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + checkpointReadQueue, + stateTtl, + nodeStateManager + ); + + this.batchSize = AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, it -> this.batchSize = it); + + this.expectedExecutionTimeInMilliSecsPerRequest = AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS + .get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer( + EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, + it -> this.expectedExecutionTimeInMilliSecsPerRequest = it + ); + } + + @Override + protected List transformRequests(List requests) { + // guarantee we only send low priority requests + return requests.stream().filter(request -> request.priority == RequestPriority.LOW).collect(Collectors.toList()); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java 
b/src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java index 9861e5056..62bd0a2bd 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java +++ b/src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java @@ -19,13 +19,13 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; /** * A queue to run concurrent requests (either batch or single request). @@ -132,7 +132,7 @@ public void maintenance() { */ @Override protected void triggerProcess() { - threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME).execute(() -> { + threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME).execute(() -> { if (permits.tryAcquire()) { try { lastExecuteTime = clock.instant(); diff --git a/src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java-e b/src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java-e new file mode 100644 index 000000000..62bd0a2bd --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java-e @@ -0,0 +1,161 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Random; +import java.util.concurrent.Semaphore; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +/** + * A queue to run concurrent requests (either batch or single request). + * The concurrency is configurable. The callers use the put method to put requests + * in and the queue tries to execute them if there are concurrency slots. + * + * @param Individual request type that is a subtype of ADRequest + */ +public abstract class ConcurrentWorker extends RateLimitedRequestWorker { + private static final Logger LOG = LogManager.getLogger(ConcurrentWorker.class); + + private Semaphore permits; + + private Instant lastExecuteTime; + private Duration executionTtl; + + /** + * + * Constructor with dependencies and configuration. + * + * @param queueName queue's name + * @param heapSizeInBytes ES heap size + * @param singleRequestSizeInBytes single request's size in bytes + * @param maxHeapPercentForQueueSetting max heap size used for the queue. Used for + * rate AD's usage on ES threadpools. 
+ * @param clusterService Cluster service accessor + * @param random Random number generator + * @param adCircuitBreakerService AD Circuit breaker service + * @param threadPool threadpool accessor + * @param settings Cluster settings getter + * @param maxQueuedTaskRatio maximum queued tasks ratio in ES threadpools + * @param clock Clock to get current time + * @param mediumSegmentPruneRatio the percent of medium priority requests to prune when the queue is full + * @param lowSegmentPruneRatio the percent of low priority requests to prune when the queue is full + * @param maintenanceFreqConstant a constant help define the frequency of maintenance. We cannot do + * the expensive maintenance too often. + * @param concurrencySetting Max concurrent processing of the queued events + * @param executionTtl Max execution time of a single request + * @param stateTtl max idle state duration. Used to clean unused states. + * @param nodeStateManager node state accessor + */ + public ConcurrentWorker( + String queueName, + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Setting concurrencySetting, + Duration executionTtl, + Duration stateTtl, + NodeStateManager nodeStateManager + ) { + super( + queueName, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + stateTtl, + nodeStateManager + ); + + this.permits = new Semaphore(concurrencySetting.get(settings)); + clusterService.getClusterSettings().addSettingsUpdateConsumer(concurrencySetting, it -> permits = new Semaphore(it)); + + this.lastExecuteTime = clock.instant(); + this.executionTtl = executionTtl; + } + + @Override + public void maintenance() { + super.maintenance(); + + if (lastExecuteTime.plus(executionTtl).isBefore(clock.instant()) && permits.availablePermits() == 0 && false == isQueueEmpty()) { + LOG.warn("previous execution has been running for too long. Maybe there are bugs."); + + // Release one permit. This is a stop gap solution as I don't know + // whether the system is under heavy workload or not. Release multiple + // permits might cause the situation even worse. So I am conservative here. + permits.release(); + } + } + + /** + * try to execute queued requests if there are concurrency slots and return right away. + */ + @Override + protected void triggerProcess() { + threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME).execute(() -> { + if (permits.tryAcquire()) { + try { + lastExecuteTime = clock.instant(); + execute(() -> { + permits.release(); + process(); + }, () -> { permits.release(); }); + } catch (Exception e) { + permits.release(); + // throw to the root level to catch + throw e; + } + } + }); + } + + /** + * Execute requests in toProcess. The implementation needs to call cleanUp after done. + * The 1st callback is executed after processing one request. So we keep looking for + * new requests if there is any after finishing one request. Otherwise, just release + * (the 2nd callback) without calling process. 
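A hedged sketch of the permit-gated execution pattern that ConcurrentWorker.triggerProcess uses (shown in the diff above), reduced to plain JDK types and simplified to synchronous work; the real worker releases the permit from the asynchronous callbacks passed to execute.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;

final class PermitGatedRunnerSketch {
    private final Semaphore permits = new Semaphore(2);                // stands in for the concurrency setting
    private final ExecutorService executor = Executors.newCachedThreadPool();

    void trigger(Runnable work) {
        executor.execute(() -> {
            if (!permits.tryAcquire()) {
                return;                                                // concurrency limit reached; skip this trigger
            }
            try {
                work.run();
            } finally {
                permits.release();                                     // never leak a permit
            }
        });
    }
}
```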
+ * @param afterProcessCallback callback after processing requests + * @param emptyQueueCallback callback for empty queues + */ + protected abstract void execute(Runnable afterProcessCallback, Runnable emptyQueueCallback); +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java-e b/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java-e new file mode 100644 index 000000000..53d05ff11 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java-e @@ -0,0 +1,161 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.Locale; +import java.util.Optional; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.ml.EntityColdStarter; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelManager.ModelType; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; + +/** + * A queue for HCAD model training (a.k.a. cold start). As model training is a + * pretty expensive operation, we pull cold start requests from the queue in a + * serial fashion. Each detector has an equal chance of being pulled. The equal + * probability is achieved by putting model training requests for different + * detectors into different segments and pulling requests from segments in a + * round-robin fashion. 
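A hedged sketch of the per-detector, round-robin pulling described above; the segment bookkeeping actually lives in RateLimitedRequestWorker, so the names and structure here are illustrative only.

```java
import java.util.ArrayDeque;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Queue;

final class RoundRobinSegmentsSketch<R> {
    private final Map<String, Queue<R>> segments = new LinkedHashMap<>();

    void put(String detectorId, R request) {
        segments.computeIfAbsent(detectorId, k -> new ArrayDeque<>()).add(request);
    }

    // Pull one request from the first non-empty segment, then rotate that segment to
    // the back so every detector gets an equal chance of being picked next.
    R pollNext() {
        String chosen = null;
        R request = null;
        for (Map.Entry<String, Queue<R>> entry : segments.entrySet()) {
            request = entry.getValue().poll();
            if (request != null) {
                chosen = entry.getKey();
                break;
            }
        }
        if (chosen != null) {
            Queue<R> rotated = segments.remove(chosen);
            if (!rotated.isEmpty()) {
                segments.put(chosen, rotated);
            }
        }
        return request;
    }
}
```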
+ * + */ +public class EntityColdStartWorker extends SingleRequestWorker { + private static final Logger LOG = LogManager.getLogger(EntityColdStartWorker.class); + public static final String WORKER_NAME = "cold-start"; + + private final EntityColdStarter entityColdStarter; + private final CacheProvider cacheProvider; + + public EntityColdStartWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + EntityColdStarter entityColdStarter, + Duration stateTtl, + NodeStateManager nodeStateManager, + CacheProvider cacheProvider + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + ENTITY_COLD_START_QUEUE_CONCURRENCY, + executionTtl, + stateTtl, + nodeStateManager + ); + this.entityColdStarter = entityColdStarter; + this.cacheProvider = cacheProvider; + } + + @Override + protected void executeRequest(EntityRequest coldStartRequest, ActionListener listener) { + String detectorId = coldStartRequest.getId(); + + Optional modelId = coldStartRequest.getModelId(); + + if (false == modelId.isPresent()) { + String error = String.format(Locale.ROOT, "Fail to get model id for request %s", coldStartRequest); + LOG.warn(error); + listener.onFailure(new RuntimeException(error)); + return; + } + + ModelState modelState = new ModelState<>( + new EntityModel(coldStartRequest.getEntity(), new ArrayDeque<>(), null), + modelId.get(), + detectorId, + ModelType.ENTITY.getName(), + clock, + 0 + ); + + ActionListener coldStartListener = ActionListener.wrap(r -> { + nodeStateManager.getAnomalyDetector(detectorId, ActionListener.wrap(detectorOptional -> { + try { + if (!detectorOptional.isPresent()) { + LOG + .error( + new ParameterizedMessage( + "fail to load trained model [{}] to cache due to the detector not being found.", + modelState.getModelId() + ) + ); + return; + } + AnomalyDetector detector = detectorOptional.get(); + EntityModel model = modelState.getModel(); + // load to cache if cold start succeeds + if (model != null && model.getTrcf() != null) { + cacheProvider.get().hostIfPossible(detector, modelState); + } + } finally { + listener.onResponse(null); + } + }, listener::onFailure)); + + }, e -> { + try { + if (ExceptionUtil.isOverloaded(e)) { + LOG.error("OpenSearch is overloaded"); + setCoolDownStart(); + } + nodeStateManager.setException(detectorId, e); + } finally { + listener.onFailure(e); + } + }); + + entityColdStarter.trainModel(coldStartRequest.getEntity(), detectorId, modelState, coldStartListener); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/EntityFeatureRequest.java-e b/src/main/java/org/opensearch/ad/ratelimit/EntityFeatureRequest.java-e new file mode 100644 index 000000000..875974dbb --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/EntityFeatureRequest.java-e @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source 
license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import org.opensearch.timeseries.model.Entity; + +public class EntityFeatureRequest extends EntityRequest { + private final double[] currentFeature; + private final long dataStartTimeMillis; + + public EntityFeatureRequest( + long expirationEpochMs, + String detectorId, + RequestPriority priority, + Entity entity, + double[] currentFeature, + long dataStartTimeMs + ) { + super(expirationEpochMs, detectorId, priority, entity); + this.currentFeature = currentFeature; + this.dataStartTimeMillis = dataStartTimeMs; + } + + public double[] getCurrentFeature() { + return currentFeature; + } + + public long getDataStartTimeMillis() { + return dataStartTimeMillis; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/EntityRequest.java-e b/src/main/java/org/opensearch/ad/ratelimit/EntityRequest.java-e new file mode 100644 index 000000000..7acf2652a --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/EntityRequest.java-e @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import java.util.Optional; + +import org.opensearch.timeseries.model.Entity; + +public class EntityRequest extends QueuedRequest { + private final Entity entity; + + /** + * + * @param expirationEpochMs Expiry time of the request + * @param detectorId Detector Id + * @param priority the entity's priority + * @param entity the entity's attributes + */ + public EntityRequest(long expirationEpochMs, String detectorId, RequestPriority priority, Entity entity) { + super(expirationEpochMs, detectorId, priority); + this.entity = entity; + } + + public Entity getEntity() { + return entity; + } + + public Optional getModelId() { + return entity.getModelId(detectorId); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/QueuedRequest.java-e b/src/main/java/org/opensearch/ad/ratelimit/QueuedRequest.java-e new file mode 100644 index 000000000..66c440db9 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/QueuedRequest.java-e @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +public abstract class QueuedRequest { + protected long expirationEpochMs; + protected String detectorId; + protected RequestPriority priority; + + /** + * + * @param expirationEpochMs Request expiry time in milliseconds + * @param detectorId Detector Id + * @param priority how urgent the request is + */ + protected QueuedRequest(long expirationEpochMs, String detectorId, RequestPriority priority) { + this.expirationEpochMs = expirationEpochMs; + this.detectorId = detectorId; + this.priority = priority; + } + + protected QueuedRequest() {} + + public long getExpirationEpochMs() { + return expirationEpochMs; + } + + /** + * A queue consists of various segments with different priority. A queued + * request belongs one segment. The subtype will define the id. 
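+ * <p>A sketch of the convention RateLimitedRequestWorker uses to pick the segment
+ * (RequestQueue) key for a queued request; this mirrors its putOnly() method:
+ * <pre>{@code
+ * String segmentId = RequestPriority.MEDIUM == request.getPriority()
+ *     ? request.getId()               // medium priority: one segment per detector
+ *     : request.getPriority().name(); // high/low priority: one shared segment each
+ * }</pre>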
+ * @return Segment Id + */ + public RequestPriority getPriority() { + return priority; + } + + public void setPriority(RequestPriority priority) { + this.priority = priority; + } + + public String getId() { + return detectorId; + } + + public void setDetectorId(String detectorId) { + this.detectorId = detectorId; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java b/src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java index 1eea96337..770c79b96 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java +++ b/src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java @@ -33,7 +33,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.ExpiringState; import org.opensearch.ad.MaintenanceState; import org.opensearch.ad.NodeStateManager; @@ -44,6 +43,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.threadpool.ThreadPool; import org.opensearch.threadpool.ThreadPoolStats; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.TimeSeriesException; /** @@ -551,7 +551,7 @@ protected void process() { } catch (Exception e) { LOG.error(new ParameterizedMessage("Fail to process requests in [{}].", this.workerName), e); } - }, new TimeValue(coolDownMinutes, TimeUnit.MINUTES), AnomalyDetectorPlugin.AD_THREAD_POOL_NAME); + }, new TimeValue(coolDownMinutes, TimeUnit.MINUTES), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); } else { try { triggerProcess(); diff --git a/src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java-e b/src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java-e new file mode 100644 index 000000000..770c79b96 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java-e @@ -0,0 +1,573 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.COOLDOWN_MINUTES; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ad.ExpiringState; +import org.opensearch.ad.MaintenanceState; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.threadpool.ThreadPoolStats; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.TimeSeriesException; + +/** + * HCAD can bombard Opensearch with “thundering herd” traffic, in which many entities + * make requests that need similar Opensearch reads/writes at approximately the same + * time. To remedy this issue we queue the requests and ensure that only a + * limited set of requests are out for Opensearch reads/writes. + * + * @param Individual request type that is a subtype of ADRequest + */ +public abstract class RateLimitedRequestWorker implements MaintenanceState { + /** + * Each request is associated with a RequestQueue. That is, a queue consists of RequestQueues. + * RequestQueues have their corresponding priorities: HIGH, MEDIUM, and LOW. An example + * of HIGH priority requests is anomaly results with errors or its anomaly grade + * larger than zero. An example of MEDIUM priority requests is a cold start request + * for an entity. An example of LOW priority requests is checkpoint write requests + * for a cold entity. LOW priority requests have the slightest chance to be selected + * to be executed. MEDIUM and HIGH priority requests have higher stakes. LOW priority + * requests have higher chances of being deleted when the size of the queue reaches + * beyond a limit compared to MEDIUM/HIGH priority requests. + * + */ + class RequestQueue implements ExpiringState { + /* + * last access time of the RequestQueue + * This does not have to be precise, just a signal for unused old RequestQueue + * that can be removed. It is fine if we have race condition. Don't want + * to synchronize the access as this could penalize performance. + */ + private Instant lastAccessTime; + // data structure to hold requests. Cannot be reassigned. This is to + // guarantee a RequestQueue's content cannot be null. 
+ private final BlockingQueue content; + + RequestQueue() { + this.lastAccessTime = clock.instant(); + this.content = new LinkedBlockingQueue(); + } + + @Override + public boolean expired(Duration stateTtl) { + return expired(lastAccessTime, stateTtl, clock.instant()); + } + + public void put(RequestType request) throws InterruptedException { + this.content.put(request); + } + + public int size() { + return this.content.size(); + } + + public boolean isEmpty() { + return content.size() == 0; + } + + /** + * Remove requests in the queue + * @param numberToRemove number of requests to remove + * @return removed requests + */ + public int drain(int numberToRemove) { + int removed = 0; + while (removed <= numberToRemove) { + if (content.poll() != null) { + removed++; + } else { + // stop if the queue is empty + break; + } + } + return removed; + } + + /** + * Remove requests in the queue + * @param removeRatio the removing ratio + * @return removed requests + */ + public int drain(float removeRatio) { + int numberToRemove = (int) (content.size() * removeRatio); + return drain(numberToRemove); + } + + /** + * Remove expired requests + * + * In terms of request duration, HCAD throws a request out if it + * is older than the detector frequency. This duration limit frees + * up HCAD to work on newer requests in the subsequent detection + * interval instead of piling up requests that no longer matter. + * For example, loading model checkpoints for cache misses requires + * a queue configured in front of it. A request contains the checkpoint + * document Id and the expiry time, and the queue can hold a considerable + * volume of such requests since the size of the request is small. + * The expiry time is the start timestamp of the next detector run. + * Enforcing the expiry time places an upper bound on each request’s + * lifetime. + * + * @return the number of removed requests + */ + public int clearExpiredRequests() { + int removed = 0; + RequestType head = content.peek(); + while (head != null && head.getExpirationEpochMs() < clock.millis()) { + content.poll(); + removed++; + head = content.peek(); + } + return removed; + } + } + + private static final Logger LOG = LogManager.getLogger(RateLimitedRequestWorker.class); + + protected volatile int queueSize; + protected final String workerName; + private final long heapSize; + private final int singleRequestSize; + private float maxHeapPercentForQueue; + + // map from RequestQueue Id to its RequestQueue. + // For high priority requests, the RequestQueue id is RequestPriority.HIGH.name(). + // For low priority requests, the RequestQueue id is RequestPriority.LOW.name(). + // For medium priority requests, the RequestQueue id is detector id. The objective + // is to separate requests from different detectors and fairly process requests + // from each detector. 
+ protected final ConcurrentSkipListMap requestQueues; + private String lastSelectedRequestQueueId; + protected Random random; + private ADCircuitBreakerService adCircuitBreakerService; + protected ThreadPool threadPool; + protected Instant cooldownStart; + protected int coolDownMinutes; + private float maxQueuedTaskRatio; + protected Clock clock; + private float mediumRequestQueuePruneRatio; + private float lowRequestQueuePruneRatio; + protected int maintenanceFreqConstant; + private final Duration stateTtl; + protected final NodeStateManager nodeStateManager; + + public RateLimitedRequestWorker( + String workerName, + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumRequestQueuePruneRatio, + float lowRequestQueuePruneRatio, + int maintenanceFreqConstant, + Duration stateTtl, + NodeStateManager nodeStateManager + ) { + this.heapSize = heapSizeInBytes; + this.singleRequestSize = singleRequestSizeInBytes; + this.maxHeapPercentForQueue = maxHeapPercentForQueueSetting.get(settings); + this.queueSize = (int) (heapSizeInBytes * maxHeapPercentForQueue / singleRequestSizeInBytes); + clusterService.getClusterSettings().addSettingsUpdateConsumer(maxHeapPercentForQueueSetting, it -> { + int oldQueueSize = queueSize; + this.maxHeapPercentForQueue = it; + this.queueSize = (int) (this.heapSize * maxHeapPercentForQueue / this.singleRequestSize); + LOG.info(new ParameterizedMessage("Queue size changed from [{}] to [{}]", oldQueueSize, queueSize)); + }); + + this.workerName = workerName; + this.random = random; + this.adCircuitBreakerService = adCircuitBreakerService; + this.threadPool = threadPool; + this.maxQueuedTaskRatio = maxQueuedTaskRatio; + this.clock = clock; + this.mediumRequestQueuePruneRatio = mediumRequestQueuePruneRatio; + this.lowRequestQueuePruneRatio = lowRequestQueuePruneRatio; + + this.lastSelectedRequestQueueId = null; + this.requestQueues = new ConcurrentSkipListMap<>(); + this.cooldownStart = Instant.MIN; + this.coolDownMinutes = (int) (COOLDOWN_MINUTES.get(settings).getMinutes()); + this.maintenanceFreqConstant = maintenanceFreqConstant; + this.stateTtl = stateTtl; + this.nodeStateManager = nodeStateManager; + } + + protected String getWorkerName() { + return workerName; + } + + /** + * To add fairness to multiple detectors, HCAD allocates queues at a per + * detector granularity and pulls off requests across similar queues in a + * round-robin fashion. This way, if one detector has a much higher + * cardinality than other detectors, the unfinished portion of that + * detector’s workload times out, and other detectors’ workloads continue + * operating with predictable performance. For example, for loading checkpoints, + * HCAD pulls off 10 requests from one detector’s queues, issues an mget request + * to ES, waits for it to finish, and then does it again for other detectors’ + * queues. If one queue does not have more than 10 requests, HCAD dequeues + * the next batches of messages in the round-robin schedule.
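+ * <p>A caller-side sketch of taking one batch from the next segment (the real batching
+ * loop is getRequests below; {@code batch} and {@code batchSize} are assumed to be in scope):
+ * <pre>{@code
+ * Optional<BlockingQueue<RequestType>> next = selectNextQueue();
+ * if (next.isPresent()) {
+ *     // drain at most one batch from this detector's segment before moving on
+ *     next.get().drainTo(batch, batchSize);
+ * }
+ * }</pre>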
+ * @return next queue to fetch requests + */ + protected Optional> selectNextQueue() { + if (true == requestQueues.isEmpty()) { + return Optional.empty(); + } + + String startId = lastSelectedRequestQueueId; + try { + for (int i = 0; i < requestQueues.size(); i++) { + if (startId == null || requestQueues.size() == 1 || startId.equals(requestQueues.lastKey())) { + startId = requestQueues.firstKey(); + } else { + startId = requestQueues.higherKey(startId); + } + + if (startId.equals(RequestPriority.LOW.name())) { + continue; + } + + RequestQueue requestQueue = requestQueues.get(startId); + if (requestQueue == null) { + continue; + } + + requestQueue.clearExpiredRequests(); + + if (false == requestQueue.isEmpty()) { + return Optional.of(requestQueue.content); + } + } + + RequestQueue requestQueue = requestQueues.get(RequestPriority.LOW.name()); + + if (requestQueue != null) { + requestQueue.clearExpiredRequests(); + if (false == requestQueue.isEmpty()) { + return Optional.of(requestQueue.content); + } + } + // if we haven't find a non-empty queue , return empty. + return Optional.empty(); + } finally { + // it is fine we may have race conditions. We are not trying to + // be precise. The objective is to select each RequestQueue with equal probability. + lastSelectedRequestQueueId = startId; + } + } + + protected void putOnly(RequestType request) { + try { + // consider MEDIUM priority here because only medium priority RequestQueues use + // detector id as the key of the RequestQueue map. low and high priority requests + // just use the RequestQueue priority (i.e., low or high) as the key of the RequestQueue map. + RequestQueue requestQueue = requestQueues + .computeIfAbsent( + RequestPriority.MEDIUM == request.getPriority() ? request.getId() : request.getPriority().name(), + k -> new RequestQueue() + ); + + requestQueue.lastAccessTime = clock.instant(); + requestQueue.put(request); + } catch (Exception e) { + LOG.error(new ParameterizedMessage("Failed to add requests to [{}]", this.workerName), e); + } + } + + private void maintainForThreadPool() { + for (final ThreadPoolStats.Stats stats : threadPool.stats()) { + String name = stats.getName(); + // we mostly use these 3 threadpools + if (ThreadPool.Names.SEARCH.equals(name) || ThreadPool.Names.GET.equals(name) || ThreadPool.Names.WRITE.equals(name)) { + int maxQueueSize = (int) (maxQueuedTaskRatio * threadPool.info(name).getQueueSize().singles()); + // in case that users set queue size to -1 (unbounded) + if (maxQueueSize > 0 && stats.getQueue() > maxQueueSize) { + LOG.info(new ParameterizedMessage("Queue [{}] size [{}], reached limit [{}]", name, stats.getQueue(), maxQueueSize)); + setCoolDownStart(); + break; + } + } + } + } + + private void prune(Map requestQueues) { + // pruning expired requests + pruneExpired(); + + // prune a few requests in each queue + for (Map.Entry requestQueueEntry : requestQueues.entrySet()) { + if (requestQueueEntry.getKey().equals(RequestPriority.HIGH.name())) { + continue; + } + + RequestQueue requestQueue = requestQueueEntry.getValue(); + + if (requestQueue == null || requestQueue.isEmpty()) { + continue; + } + + // remove more requests in the low priority RequestQueue + if (requestQueueEntry.getKey().equals(RequestPriority.LOW.name())) { + requestQueue.drain(lowRequestQueuePruneRatio); + } else { + requestQueue.drain(mediumRequestQueuePruneRatio); + } + } + } + + /** + * pruning expired requests + * + * @return the total number of deleted requests + */ + private int pruneExpired() { + int deleted = 0; + for 
(Map.Entry requestQueueEntry : requestQueues.entrySet()) { + RequestQueue requestQueue = requestQueueEntry.getValue(); + + if (requestQueue == null) { + continue; + } + + deleted += requestQueue.clearExpiredRequests(); + } + return deleted; + } + + private void prune(Map requestQueues, int exceededSize) { + // pruning expired requests + int leftItemsToRemove = exceededSize - pruneExpired(); + + if (leftItemsToRemove <= 0) { + return; + } + + // used to compute the average number of requests to remove in medium priority queues + int numberOfQueuesToExclude = 0; + + // remove low-priority requests + RequestQueue requestQueue = requestQueues.get(RequestPriority.LOW.name()); + if (requestQueue != null) { + int removedFromLow = requestQueue.drain(leftItemsToRemove); + if (removedFromLow >= leftItemsToRemove) { + return; + } else { + numberOfQueuesToExclude++; + leftItemsToRemove -= removedFromLow; + } + } + + // skip high-priority requests + if (requestQueues.get(RequestPriority.HIGH.name()) != null) { + numberOfQueuesToExclude++; + } + + int numberOfRequestsToRemoveInMediumQueues = leftItemsToRemove / (requestQueues.size() - numberOfQueuesToExclude); + + for (Map.Entry requestQueueEntry : requestQueues.entrySet()) { + if (requestQueueEntry.getKey().equals(RequestPriority.HIGH.name()) + || requestQueueEntry.getKey().equals(RequestPriority.LOW.name())) { + continue; + } + + requestQueue = requestQueueEntry.getValue(); + + if (requestQueue == null) { + continue; + } + + requestQueue.drain(numberOfRequestsToRemoveInMediumQueues); + } + } + + private void maintainForMemory() { + // remove expired RequestQueues + maintenance(requestQueues, stateTtl); + + int exceededSize = exceededSize(); + if (exceededSize > 0) { + prune(requestQueues, exceededSize); + } else if (adCircuitBreakerService.isOpen()) { + // remove a few items in each RequestQueue + prune(requestQueues); + } + } + + private int exceededSize() { + Collection queues = requestQueues.values(); + int totalSize = 0; + + // When faced with a backlog beyond the limit, we prefer fresh requests + // and throw away old requests. + // release space so that put won't block + for (RequestQueue q : queues) { + totalSize += q.size(); + } + return totalSize - queueSize; + } + + public boolean isQueueEmpty() { + Collection queues = requestQueues.values(); + for (RequestQueue q : queues) { + if (q.size() > 0) { + return false; + } + } + return true; + } + + @Override + public void maintenance() { + try { + maintainForMemory(); + maintainForThreadPool(); + } catch (Exception e) { + LOG.warn("Failed to maintain", e); + } + } + + /** + * Start cooldown during an overloaded situation + */ + protected void setCoolDownStart() { + cooldownStart = clock.instant(); + } + + /** + * @param batchSize the max number of requests to fetch + * @return a list of batchSize requests (can be less) + */ + protected List getRequests(int batchSize) { + List toProcess = new ArrayList<>(batchSize); + + Set> selectedQueue = new HashSet<>(); + + while (toProcess.size() < batchSize) { + Optional> queue = selectNextQueue(); + if (false == queue.isPresent()) { + // no queue has requests + break; + } + + BlockingQueue nextToProcess = queue.get(); + if (selectedQueue.contains(nextToProcess)) { + // we have gone around all of the queues + break; + } + selectedQueue.add(nextToProcess); + + List requests = new ArrayList<>(); + // concurrent requests will wait to prevent concurrent draining.
+ // This is fine since the operation is fast + nextToProcess.drainTo(requests, batchSize); + toProcess.addAll(requests); + } + + return toProcess; + } + + /** + * Enqueuing runs asynchronously: we put requests in a queue, try to execute + * them. The thread executing requests won't block the thread inserting + * requests to the queue. + * @param request Individual request + */ + public void put(RequestType request) { + if (request == null) { + return; + } + putOnly(request); + + process(); + } + + public void putAll(List requests) { + if (requests == null || requests.isEmpty()) { + return; + } + try { + for (RequestType request : requests) { + putOnly(request); + } + + process(); + } catch (Exception e) { + LOG.error(new ParameterizedMessage("Failed to add requests to [{}]", getWorkerName()), e); + } + } + + protected void process() { + if (random.nextInt(maintenanceFreqConstant) == 1) { + maintenance(); + } + + // still in cooldown period + if (cooldownStart.plus(Duration.ofMinutes(coolDownMinutes)).isAfter(clock.instant())) { + threadPool.schedule(() -> { + try { + process(); + } catch (Exception e) { + LOG.error(new ParameterizedMessage("Fail to process requests in [{}].", this.workerName), e); + } + }, new TimeValue(coolDownMinutes, TimeUnit.MINUTES), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); + } else { + try { + triggerProcess(); + } catch (Exception e) { + LOG.error(String.format(Locale.ROOT, "Failed to process requests from %s", getWorkerName()), e); + if (e != null && e instanceof TimeSeriesException) { + TimeSeriesException adExep = (TimeSeriesException) e; + nodeStateManager.setException(adExep.getConfigId(), adExep); + } + } + + } + } + + /** + * How to execute requests is abstracted out and left to RateLimitedQueue's subclasses to implement. + */ + protected abstract void triggerProcess(); +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/RequestPriority.java-e b/src/main/java/org/opensearch/ad/ratelimit/RequestPriority.java-e new file mode 100644 index 000000000..3193d2285 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/RequestPriority.java-e @@ -0,0 +1,18 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +public enum RequestPriority { + LOW, + MEDIUM, + HIGH +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java b/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java index 7acef66a7..a25bf3924 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java +++ b/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java @@ -14,9 +14,9 @@ import java.io.IOException; import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; public class ResultWriteRequest extends QueuedRequest implements Writeable { private final AnomalyResult result; diff --git a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java-e b/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java-e new file mode 100644 index 000000000..a25bf3924 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java-e @@ -0,0 +1,56 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import java.io.IOException; + +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; + +public class ResultWriteRequest extends QueuedRequest implements Writeable { + private final AnomalyResult result; + // If resultIndex is null, result will be stored in default result index. 
+ private final String resultIndex; + + public ResultWriteRequest( + long expirationEpochMs, + String detectorId, + RequestPriority priority, + AnomalyResult result, + String resultIndex + ) { + super(expirationEpochMs, detectorId, priority); + this.result = result; + this.resultIndex = resultIndex; + } + + public ResultWriteRequest(StreamInput in) throws IOException { + this.result = new AnomalyResult(in); + this.resultIndex = in.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + result.writeTo(out); + out.writeOptionalString(resultIndex); + } + + public AnomalyResult getResult() { + return result; + } + + public String getCustomResultIndex() { + return resultIndex; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java index dad1b409b..2381e5db9 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java +++ b/src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java @@ -35,12 +35,12 @@ import org.opensearch.ad.transport.handler.MultiEntityResultHandler; import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.threadpool.ThreadPool; @@ -206,7 +206,7 @@ private Optional getAnomalyResult(DocWriteRequest request) { .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, indexSource, indexContentType) ) { // the first character is null. Without skipping it, we get - // org.opensearch.common.ParsingException: Failed to parse object: expecting token of type [START_OBJECT] but found + // org.opensearch.core.common.ParsingException: Failed to parse object: expecting token of type [START_OBJECT] but found // [null] xContentParser.nextToken(); return Optional.of(AnomalyResult.parse(xContentParser)); diff --git a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java-e b/src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java-e new file mode 100644 index 000000000..d2d22078b --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java-e @@ -0,0 +1,219 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Optional; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.transport.ADResultBulkRequest; +import org.opensearch.ad.transport.ADResultBulkResponse; +import org.opensearch.ad.transport.handler.MultiEntityResultHandler; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.threadpool.ThreadPool; + +public class ResultWriteWorker extends BatchWorker { + private static final Logger LOG = LogManager.getLogger(ResultWriteWorker.class); + public static final String WORKER_NAME = "result-write"; + + private final MultiEntityResultHandler resultHandler; + private NamedXContentRegistry xContentRegistry; + + public ResultWriteWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + MultiEntityResultHandler resultHandler, + NamedXContentRegistry xContentRegistry, + NodeStateManager stateManager, + Duration stateTtl + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + AD_RESULT_WRITE_QUEUE_CONCURRENCY, + executionTtl, + AD_RESULT_WRITE_QUEUE_BATCH_SIZE, + stateTtl, + stateManager + ); + this.resultHandler = resultHandler; + this.xContentRegistry = xContentRegistry; + } + + @Override + protected void executeBatchRequest(ADResultBulkRequest request, ActionListener listener) { + if (request.numberOfActions() < 1) { + listener.onResponse(null); + return; + } + resultHandler.flush(request, listener); + } + + @Override + protected ADResultBulkRequest toBatchRequest(List toProcess) { + final ADResultBulkRequest bulkRequest = new ADResultBulkRequest(); + for (ResultWriteRequest request : toProcess) { + bulkRequest.add(request); + } + return bulkRequest; + } + + @Override + protected 
ActionListener getResponseListener( + List toProcess, + ADResultBulkRequest bulkRequest + ) { + return ActionListener.wrap(adResultBulkResponse -> { + if (adResultBulkResponse == null || false == adResultBulkResponse.getRetryRequests().isPresent()) { + // all successful + return; + } + + enqueueRetryRequestIteration(adResultBulkResponse.getRetryRequests().get(), 0); + }, exception -> { + if (ExceptionUtil.isRetryAble(exception)) { + // retry all of them + super.putAll(toProcess); + } else if (ExceptionUtil.isOverloaded(exception)) { + LOG.error("too many get AD model checkpoint requests or shard not avialble"); + setCoolDownStart(); + } + + for (ResultWriteRequest request : toProcess) { + nodeStateManager.setException(request.getId(), exception); + } + LOG.error("Fail to save results", exception); + }); + } + + private void enqueueRetryRequestIteration(List requestToRetry, int index) { + if (index >= requestToRetry.size()) { + return; + } + DocWriteRequest currentRequest = requestToRetry.get(index); + Optional resultToRetry = getAnomalyResult(currentRequest); + if (false == resultToRetry.isPresent()) { + enqueueRetryRequestIteration(requestToRetry, index + 1); + return; + } + AnomalyResult result = resultToRetry.get(); + String detectorId = result.getConfigId(); + nodeStateManager.getAnomalyDetector(detectorId, onGetDetector(requestToRetry, index, detectorId, result)); + } + + private ActionListener> onGetDetector( + List requestToRetry, + int index, + String detectorId, + AnomalyResult resultToRetry + ) { + return ActionListener.wrap(detectorOptional -> { + if (false == detectorOptional.isPresent()) { + LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + enqueueRetryRequestIteration(requestToRetry, index + 1); + return; + } + + AnomalyDetector detector = detectorOptional.get(); + super.put( + new ResultWriteRequest( + // expire based on execute start time + resultToRetry.getExecutionStartTime().toEpochMilli() + detector.getIntervalInMilliseconds(), + detectorId, + resultToRetry.isHighPriority() ? RequestPriority.HIGH : RequestPriority.MEDIUM, + resultToRetry, + detector.getCustomResultIndex() + ) + ); + + enqueueRetryRequestIteration(requestToRetry, index + 1); + + }, exception -> { + LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception); + enqueueRetryRequestIteration(requestToRetry, index + 1); + }); + } + + private Optional getAnomalyResult(DocWriteRequest request) { + try { + if (false == (request instanceof IndexRequest)) { + LOG.error(new ParameterizedMessage("We should only send IndexRquest, but get [{}].", request)); + return Optional.empty(); + } + // we send IndexRequest previously + IndexRequest indexRequest = (IndexRequest) request; + BytesReference indexSource = indexRequest.source(); + XContentType indexContentType = indexRequest.getContentType(); + try ( + XContentParser xContentParser = XContentHelper + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, indexSource, indexContentType) + ) { + // the first character is null. 
Without skipping it, we get + // org.opensearch.core.common.ParsingException: Failed to parse object: expecting token of type [START_OBJECT] but found + // [null] + xContentParser.nextToken(); + return Optional.of(AnomalyResult.parse(xContentParser)); + } + } catch (Exception e) { + LOG.error(new ParameterizedMessage("Fail to parse index request [{}]", request), e); + } + return Optional.empty(); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ScheduledWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ScheduledWorker.java index 1bfeec9af..9d4891b7c 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ScheduledWorker.java +++ b/src/main/java/org/opensearch/ad/ratelimit/ScheduledWorker.java @@ -18,7 +18,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.cluster.service.ClusterService; @@ -26,6 +25,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; public abstract class ScheduledWorker extends RateLimitedRequestWorker { @@ -114,7 +114,7 @@ private void pullRequests() { private synchronized void schedulePulling(TimeValue delay) { try { - threadPool.schedule(this::pullRequests, delay, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME); + threadPool.schedule(this::pullRequests, delay, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); } catch (Exception e) { LOG.error("Fail to schedule cold entity pulling", e); } diff --git a/src/main/java/org/opensearch/ad/ratelimit/ScheduledWorker.java-e b/src/main/java/org/opensearch/ad/ratelimit/ScheduledWorker.java-e new file mode 100644 index 000000000..9d4891b7c --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ScheduledWorker.java-e @@ -0,0 +1,151 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +public abstract class ScheduledWorker extends + RateLimitedRequestWorker { + private static final Logger LOG = LogManager.getLogger(ColdEntityWorker.class); + + // the number of requests forwarded to the target queue + protected volatile int batchSize; + private final RateLimitedRequestWorker targetQueue; + // indicate whether a future pull over cold entity queues is scheduled + private boolean scheduled; + protected volatile int expectedExecutionTimeInMilliSecsPerRequest; + + public ScheduledWorker( + String workerName, + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + RateLimitedRequestWorker targetQueue, + Duration stateTtl, + NodeStateManager nodeStateManager + ) { + super( + workerName, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + stateTtl, + nodeStateManager + ); + + this.targetQueue = targetQueue; + this.scheduled = false; + } + + private void pullRequests() { + int pulledRequestSize = 0; + int filteredRequestSize = 0; + try { + List requests = getRequests(batchSize); + if (requests == null || requests.isEmpty()) { + return; + } + pulledRequestSize = requests.size(); + List filteredRequests = transformRequests(requests); + if (!filteredRequests.isEmpty()) { + targetQueue.putAll(filteredRequests); + filteredRequestSize = filteredRequests.size(); + } + } catch (Exception e) { + LOG.error("Error enqueuing cold entity requests", e); + } finally { + if (pulledRequestSize < batchSize) { + scheduled = false; + } else { + // there might be more to fetch + // schedule a pull from queue every few seconds. + scheduled = true; + if (filteredRequestSize == 0) { + pullRequests(); + } else { + schedulePulling(getScheduleDelay(filteredRequestSize)); + } + } + } + } + + private synchronized void schedulePulling(TimeValue delay) { + try { + threadPool.schedule(this::pullRequests, delay, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); + } catch (Exception e) { + LOG.error("Fail to schedule cold entity pulling", e); + } + } + + /** + * only pull requests to process when there's no other scheduled run + */ + @Override + protected void triggerProcess() { + if (false == scheduled) { + pullRequests(); + } + } + + /** + * The method calculates the delay we have to set to control the rate of cold + * entity processing. We wait longer if the requestSize is larger to give the + * system more time to processing requests. 
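+ * <p>The delay is simply {@code requestSize * expectedExecutionTimeInMilliSecsPerRequest}.
+ * For example (numbers are illustrative), forwarding 60 requests with an expected
+ * execution time of 100 ms per request schedules the next pull 6,000 ms later:
+ * <pre>{@code
+ * // 60 requests * 100 ms/request = 6000 ms = 6 seconds
+ * TimeValue delay = TimeValue.timeValueMillis(60 * 100);
+ * }</pre>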
+ * @param requestSize requests to process + * @return the delay for the next scheduled run + */ + private TimeValue getScheduleDelay(int requestSize) { + return TimeValue.timeValueMillis(requestSize * expectedExecutionTimeInMilliSecsPerRequest); + } + + /** + * Transform requests before forwarding to another queue + * @param requests requests to be transformed + * + * @return processed requests + */ + protected abstract List transformRequests(List requests); +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/SingleRequestWorker.java-e b/src/main/java/org/opensearch/ad/ratelimit/SingleRequestWorker.java-e new file mode 100644 index 000000000..028a0643f --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/SingleRequestWorker.java-e @@ -0,0 +1,107 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.BlockingQueue; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; + +public abstract class SingleRequestWorker extends ConcurrentWorker { + private static final Logger LOG = LogManager.getLogger(SingleRequestWorker.class); + + public SingleRequestWorker( + String queueName, + long heapSize, + int singleRequestSize, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + ADCircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Setting concurrencySetting, + Duration executionTtl, + Duration stateTtl, + NodeStateManager nodeStateManager + ) { + super( + queueName, + heapSize, + singleRequestSize, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + concurrencySetting, + executionTtl, + stateTtl, + nodeStateManager + ); + } + + @Override + protected void execute(Runnable afterProcessCallback, Runnable emptyQueueCallback) { + RequestType request = null; + + Optional> queueOptional = selectNextQueue(); + if (false == queueOptional.isPresent()) { + // no queue has requests + emptyQueueCallback.run(); + return; + } + + BlockingQueue queue = queueOptional.get(); + if (false == queue.isEmpty()) { + request = queue.poll(); + } + + if (request == null) { + emptyQueueCallback.run(); + return; + } + + final ActionListener handlerWithRelease = ActionListener.wrap(afterProcessCallback); + executeRequest(request, handlerWithRelease); + } + + /** + * Used by subclasses to creates customized logic to send batch requests. + * After everything finishes, the method should call listener. 
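+ * <p>A minimal sketch of a subclass implementation; {@code someService.doWork} is a
+ * stand-in for whatever the concrete worker does (e.g. cold start training):
+ * <pre>{@code
+ * protected void executeRequest(EntityRequest request, ActionListener<Void> listener) {
+ *     someService.doWork(request, ActionListener.wrap(
+ *         response -> listener.onResponse(null),        // done: the concurrency permit is released
+ *         exception -> listener.onFailure(exception))); // failure also releases the permit
+ * }
+ * }</pre>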
+ * @param request request to execute + * @param listener customized listener + */ + protected abstract void executeRequest(RequestType request, ActionListener listener); +} diff --git a/src/main/java/org/opensearch/ad/rest/AbstractAnomalyDetectorAction.java-e b/src/main/java/org/opensearch/ad/rest/AbstractAnomalyDetectorAction.java-e new file mode 100644 index 000000000..331c3151f --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/AbstractAnomalyDetectorAction.java-e @@ -0,0 +1,55 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.DETECTION_INTERVAL; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.DETECTION_WINDOW_DELAY; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ANOMALY_FEATURES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.rest.BaseRestHandler; + +public abstract class AbstractAnomalyDetectorAction extends BaseRestHandler { + + protected volatile TimeValue requestTimeout; + protected volatile TimeValue detectionInterval; + protected volatile TimeValue detectionWindowDelay; + protected volatile Integer maxSingleEntityDetectors; + protected volatile Integer maxMultiEntityDetectors; + protected volatile Integer maxAnomalyFeatures; + + public AbstractAnomalyDetectorAction(Settings settings, ClusterService clusterService) { + this.requestTimeout = REQUEST_TIMEOUT.get(settings); + this.detectionInterval = DETECTION_INTERVAL.get(settings); + this.detectionWindowDelay = DETECTION_WINDOW_DELAY.get(settings); + this.maxSingleEntityDetectors = MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings); + this.maxMultiEntityDetectors = MAX_MULTI_ENTITY_ANOMALY_DETECTORS.get(settings); + this.maxAnomalyFeatures = MAX_ANOMALY_FEATURES.get(settings); + // TODO: will add more cluster setting consumer later + // TODO: inject ClusterSettings only if clusterService is only used to get ClusterSettings + clusterService.getClusterSettings().addSettingsUpdateConsumer(REQUEST_TIMEOUT, it -> requestTimeout = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(DETECTION_INTERVAL, it -> detectionInterval = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(DETECTION_WINDOW_DELAY, it -> detectionWindowDelay = it); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, it -> maxSingleEntityDetectors = it); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(MAX_MULTI_ENTITY_ANOMALY_DETECTORS, it -> maxMultiEntityDetectors = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_ANOMALY_FEATURES, it -> maxAnomalyFeatures = it); + } +} diff --git a/src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java b/src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java index 
5003a4739..1d0611cf7 100644 --- a/src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java +++ b/src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java @@ -27,13 +27,13 @@ import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestResponse; -import org.opensearch.rest.RestStatus; import org.opensearch.rest.action.RestResponseListener; import org.opensearch.search.builder.SearchSourceBuilder; diff --git a/src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java-e b/src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java-e new file mode 100644 index 000000000..28e870742 --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/AbstractSearchAction.java-e @@ -0,0 +1,127 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest; + +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; +import static org.opensearch.timeseries.util.RestHandlerUtils.getSourceContext; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.rest.action.RestResponseListener; +import org.opensearch.search.builder.SearchSourceBuilder; + +/** + * Abstract class to handle search request. 
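+ * <p>A hypothetical subclass only supplies the routes, the backing index, the result
+ * class, and the transport action (all names below are illustrative):
+ * <pre>{@code
+ * public class RestSearchFooAction extends AbstractSearchAction<Foo> {
+ *     public RestSearchFooAction() {
+ *         super(ImmutableList.of("_plugins/_anomaly_detection/foo/_search"),
+ *               ImmutableList.of(),     // no deprecated paths
+ *               ".opendistro-foo-results",
+ *               Foo.class,
+ *               SearchFooAction.INSTANCE);
+ *     }
+ * }
+ * }</pre>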
+ */ +public abstract class AbstractSearchAction extends BaseRestHandler { + + protected final String index; + protected final Class clazz; + protected final List urlPaths; + protected final List> deprecatedPaths; + protected final ActionType actionType; + + private final Logger logger = LogManager.getLogger(AbstractSearchAction.class); + + public AbstractSearchAction( + List urlPaths, + List> deprecatedPaths, + String index, + Class clazz, + ActionType actionType + ) { + this.index = index; + this.clazz = clazz; + this.urlPaths = urlPaths; + this.deprecatedPaths = deprecatedPaths; + this.actionType = actionType; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ADEnabledSetting.isADEnabled()) { + throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); + } + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); + // order of response will be re-arranged everytime we use `_source`, we sometimes do this + // even if user doesn't give this field as we exclude ui_metadata if request isn't from OSD + // ref-link: https://github.com/elastic/elasticsearch/issues/17639 + searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder)); + searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); + SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(this.index); + return channel -> client.execute(actionType, searchRequest, search(channel)); + } + + protected void onFailure(RestChannel channel, Exception e) { + try { + channel.sendResponse(new BytesRestResponse(channel, e)); + } catch (Exception exception) { + logger.error("Failed to send back failure response for search AD result", exception); + } + } + + protected RestResponseListener search(RestChannel channel) { + return new RestResponseListener(channel) { + @Override + public RestResponse buildResponse(SearchResponse response) throws Exception { + if (response.isTimedOut()) { + return new BytesRestResponse(RestStatus.REQUEST_TIMEOUT, response.toString()); + } + return new BytesRestResponse(RestStatus.OK, response.toXContent(channel.newBuilder(), EMPTY_PARAMS)); + } + }; + } + + @Override + public List routes() { + List routes = new ArrayList<>(); + for (String path : urlPaths) { + routes.add(new Route(RestRequest.Method.POST, path)); + routes.add(new Route(RestRequest.Method.GET, path)); + } + return routes; + } + + @Override + public List replacedRoutes() { + List replacedRoutes = new ArrayList<>(); + for (Pair deprecatedPath : deprecatedPaths) { + replacedRoutes + .add( + new ReplacedRoute(RestRequest.Method.POST, deprecatedPath.getKey(), RestRequest.Method.POST, deprecatedPath.getValue()) + ); + replacedRoutes + .add(new ReplacedRoute(RestRequest.Method.GET, deprecatedPath.getKey(), RestRequest.Method.GET, deprecatedPath.getValue())); + + } + return replacedRoutes; + } +} diff --git a/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java b/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java index 3272b3e23..a5052c84d 100644 --- a/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java @@ -12,7 +12,7 @@ package org.opensearch.ad.rest; import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; -import static 
org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.util.RestHandlerUtils.DETECTOR_ID; import static org.opensearch.timeseries.util.RestHandlerUtils.IF_PRIMARY_TERM; import static org.opensearch.timeseries.util.RestHandlerUtils.IF_SEQ_NO; @@ -23,7 +23,6 @@ import java.util.List; import java.util.Locale; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.AnomalyDetectorJobAction; @@ -37,6 +36,7 @@ import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.model.DateRange; import com.google.common.collect.ImmutableList; @@ -107,16 +107,17 @@ public List replacedRoutes() { // start AD Job new ReplacedRoute( RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, START_JOB), + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, START_JOB), RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, START_JOB) + String + .format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, START_JOB) ), // stop AD Job new ReplacedRoute( RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, STOP_JOB), + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, STOP_JOB), RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, STOP_JOB) + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, STOP_JOB) ) ); } diff --git a/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java-e b/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java-e new file mode 100644 index 000000000..a5052c84d --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java-e @@ -0,0 +1,124 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.rest; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.RestHandlerUtils.DETECTOR_ID; +import static org.opensearch.timeseries.util.RestHandlerUtils.IF_PRIMARY_TERM; +import static org.opensearch.timeseries.util.RestHandlerUtils.IF_SEQ_NO; +import static org.opensearch.timeseries.util.RestHandlerUtils.START_JOB; +import static org.opensearch.timeseries.util.RestHandlerUtils.STOP_JOB; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.transport.AnomalyDetectorJobAction; +import org.opensearch.ad.transport.AnomalyDetectorJobRequest; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.model.DateRange; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to handle request to start/stop AD job. + */ +public class RestAnomalyDetectorJobAction extends BaseRestHandler { + + public static final String AD_JOB_ACTION = "anomaly_detector_job_action"; + private volatile TimeValue requestTimeout; + + public RestAnomalyDetectorJobAction(Settings settings, ClusterService clusterService) { + this.requestTimeout = REQUEST_TIMEOUT.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(REQUEST_TIMEOUT, it -> requestTimeout = it); + } + + @Override + public String getName() { + return AD_JOB_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ADEnabledSetting.isADEnabled()) { + throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); + } + + String detectorId = request.param(DETECTOR_ID); + long seqNo = request.paramAsLong(IF_SEQ_NO, SequenceNumbers.UNASSIGNED_SEQ_NO); + long primaryTerm = request.paramAsLong(IF_PRIMARY_TERM, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); + boolean historical = request.paramAsBoolean("historical", false); + String rawPath = request.rawPath(); + DateRange detectionDateRange = parseDetectionDateRange(request); + + AnomalyDetectorJobRequest anomalyDetectorJobRequest = new AnomalyDetectorJobRequest( + detectorId, + detectionDateRange, + historical, + seqNo, + primaryTerm, + rawPath + ); + + return channel -> client + .execute(AnomalyDetectorJobAction.INSTANCE, anomalyDetectorJobRequest, new RestToXContentListener<>(channel)); + } + + private DateRange parseDetectionDateRange(RestRequest request) throws IOException { + if (!request.hasContent()) { + return null; + } + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + DateRange dateRange = DateRange.parse(parser); + return dateRange; + } + + @Override + public List routes() { + return ImmutableList.of(); + } + + @Override + public List 
replacedRoutes() { + return ImmutableList + .of( + // start AD Job + new ReplacedRoute( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, START_JOB), + RestRequest.Method.POST, + String + .format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, START_JOB) + ), + // stop AD Job + new ReplacedRoute( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, STOP_JOB), + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, STOP_JOB) + ) + ); + } +} diff --git a/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java index ef59cbd6e..b7a3aae6c 100644 --- a/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java @@ -19,7 +19,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.rest.handler.AnomalyDetectorActionHandler; import org.opensearch.ad.settings.ADEnabledSetting; @@ -29,6 +28,7 @@ import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import com.google.common.collect.ImmutableList; @@ -73,9 +73,9 @@ public List replacedRoutes() { // delete anomaly detector document new ReplacedRoute( RestRequest.Method.DELETE, - String.format(Locale.ROOT, "%s/{%s}", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID), + String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID), RestRequest.Method.DELETE, - String.format(Locale.ROOT, "%s/{%s}", AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID) + String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID) ) ); } diff --git a/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java-e b/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java-e new file mode 100644 index 000000000..b7a3aae6c --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java-e @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.rest; + +import static org.opensearch.timeseries.util.RestHandlerUtils.DETECTOR_ID; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.rest.handler.AnomalyDetectorActionHandler; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.transport.DeleteAnomalyDetectorAction; +import org.opensearch.ad.transport.DeleteAnomalyDetectorRequest; +import org.opensearch.client.node.NodeClient; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to delete anomaly detector. + */ +public class RestDeleteAnomalyDetectorAction extends BaseRestHandler { + + public static final String DELETE_ANOMALY_DETECTOR_ACTION = "delete_anomaly_detector"; + + private static final Logger logger = LogManager.getLogger(RestDeleteAnomalyDetectorAction.class); + private final AnomalyDetectorActionHandler handler = new AnomalyDetectorActionHandler(); + + public RestDeleteAnomalyDetectorAction() {} + + @Override + public String getName() { + return DELETE_ANOMALY_DETECTOR_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ADEnabledSetting.isADEnabled()) { + throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); + } + + String detectorId = request.param(DETECTOR_ID); + DeleteAnomalyDetectorRequest deleteAnomalyDetectorRequest = new DeleteAnomalyDetectorRequest(detectorId); + return channel -> client + .execute(DeleteAnomalyDetectorAction.INSTANCE, deleteAnomalyDetectorRequest, new RestToXContentListener<>(channel)); + } + + @Override + public List routes() { + return ImmutableList.of(); + } + + @Override + public List replacedRoutes() { + return ImmutableList + .of( + // delete anomaly detector document + new ReplacedRoute( + RestRequest.Method.DELETE, + String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID), + RestRequest.Method.DELETE, + String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID) + ) + ); + } +} diff --git a/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyResultsAction.java b/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyResultsAction.java index 5bb938ea9..a69570ce0 100644 --- a/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyResultsAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyResultsAction.java @@ -20,19 +20,19 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionListener; import org.opensearch.action.support.IndicesOptions; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.DeleteAnomalyResultsAction; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.rest.BaseRestHandler; import 
org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; -import org.opensearch.rest.RestStatus; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import com.google.common.collect.ImmutableList; @@ -85,6 +85,6 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli @Override public List routes() { - return ImmutableList.of(new Route(RestRequest.Method.DELETE, AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI + "/results")); + return ImmutableList.of(new Route(RestRequest.Method.DELETE, TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI + "/results")); } } diff --git a/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyResultsAction.java-e b/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyResultsAction.java-e new file mode 100644 index 000000000..576b6aecf --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyResultsAction.java-e @@ -0,0 +1,90 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest; + +import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN; + +import java.io.IOException; +import java.util.List; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.transport.DeleteAnomalyResultsAction; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to delete anomaly result with specific query. + * Currently AD dashboard plugin doesn't call this API. User can use this API to delete + * anomaly results to free up disk space. + * + * User needs to delete anomaly result from custom result index by themselves as they + * can directly access these custom result index. + * Same strategy for custom result index rollover. Suggest user using ISM plugin to + * manage custom result index. + * + * TODO: build better user experience to reduce user's effort to maintain custom result index. 
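
As context for the delete-results handler completed just below: it forwards a caller-supplied query as a delete-by-query over all AD result indices. The sketch that follows shows one way such a request could be assembled with the same builder chain the handler uses; it is an illustrative sketch, not part of the patch, and the field name "data_end_time" plus the cutoff value are assumptions.

import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN;

import org.opensearch.action.support.IndicesOptions;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.reindex.DeleteByQueryRequest;

public final class DeleteResultsRequestSketch {
    private DeleteResultsRequestSketch() {}

    // Matches results whose (assumed) data_end_time is at or before cutoffMillis,
    // across all AD result indices, using the same options as the handler below.
    static DeleteByQueryRequest olderThan(long cutoffMillis) {
        return new DeleteByQueryRequest(ALL_AD_RESULTS_INDEX_PATTERN)
            .setQuery(QueryBuilders.rangeQuery("data_end_time").lte(cutoffMillis))
            .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN_HIDDEN);
    }
}

A caller would send the serialized query in the request body of DELETE {AD base URI}/results, which the handler parses into the same kind of delete-by-query shown here.
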
+ */ +public class RestDeleteAnomalyResultsAction extends BaseRestHandler { + + private static final String DELETE_AD_RESULTS_ACTION = "delete_anomaly_results"; + private static final Logger logger = LogManager.getLogger(RestDeleteAnomalyResultsAction.class); + + public RestDeleteAnomalyResultsAction() {} + + @Override + public String getName() { + return DELETE_AD_RESULTS_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ADEnabledSetting.isADEnabled()) { + throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); + } + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); + DeleteByQueryRequest deleteRequest = new DeleteByQueryRequest(ALL_AD_RESULTS_INDEX_PATTERN) + .setQuery(searchSourceBuilder.query()) + .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN_HIDDEN); + return channel -> client.execute(DeleteAnomalyResultsAction.INSTANCE, deleteRequest, ActionListener.wrap(r -> { + XContentBuilder contentBuilder = r.toXContent(channel.newBuilder().startObject(), ToXContent.EMPTY_PARAMS); + contentBuilder.endObject(); + channel.sendResponse(new BytesRestResponse(RestStatus.OK, contentBuilder)); + }, e -> { + try { + channel.sendResponse(new BytesRestResponse(channel, e)); + } catch (IOException exception) { + logger.error("Failed to send back delete anomaly result exception result", exception); + } + })); + } + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.DELETE, TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI + "/results")); + } +} diff --git a/src/main/java/org/opensearch/ad/rest/RestExecuteAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestExecuteAnomalyDetectorAction.java index cafa1ee61..fe0d10ec9 100644 --- a/src/main/java/org/opensearch/ad/rest/RestExecuteAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestExecuteAnomalyDetectorAction.java @@ -12,7 +12,7 @@ package org.opensearch.ad.rest; import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.util.RestHandlerUtils.DETECTOR_ID; import static org.opensearch.timeseries.util.RestHandlerUtils.RUN; @@ -23,7 +23,6 @@ import org.apache.commons.lang.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.model.AnomalyDetectorExecutionInput; import org.opensearch.ad.settings.ADEnabledSetting; @@ -33,12 +32,13 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; -import org.opensearch.rest.RestStatus; import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import com.google.common.collect.ImmutableList; @@ -125,9 +125,9 @@ public List replacedRoutes() { // get AD result, for regular run new 
ReplacedRoute( RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, RUN), + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, RUN), RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, RUN) + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, RUN) ) ); } diff --git a/src/main/java/org/opensearch/ad/rest/RestExecuteAnomalyDetectorAction.java-e b/src/main/java/org/opensearch/ad/rest/RestExecuteAnomalyDetectorAction.java-e new file mode 100644 index 000000000..f5ac868eb --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/RestExecuteAnomalyDetectorAction.java-e @@ -0,0 +1,134 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.RestHandlerUtils.DETECTOR_ID; +import static org.opensearch.timeseries.util.RestHandlerUtils.RUN; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.apache.commons.lang.StringUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.model.AnomalyDetectorExecutionInput; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.transport.AnomalyResultAction; +import org.opensearch.ad.transport.AnomalyResultRequest; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to handle request to detect data. 
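
The execute handler completed just below converts the parsed execution input into an AnomalyResultRequest spanning the requested period. As a minimal sketch, assuming only the three-argument constructor visible in this patch (detector id plus start and end epoch milliseconds), a one-hour window ending now could be built as follows; the detector id is a placeholder.

import java.time.Duration;
import java.time.Instant;

import org.opensearch.ad.transport.AnomalyResultRequest;

public final class ExecuteWindowSketch {
    private ExecuteWindowSketch() {}

    // Builds an AnomalyResultRequest covering the last hour; "my-detector-id" is a placeholder.
    static AnomalyResultRequest lastHour() {
        Instant end = Instant.now();
        Instant start = end.minus(Duration.ofHours(1));
        return new AnomalyResultRequest("my-detector-id", start.toEpochMilli(), end.toEpochMilli());
    }
}
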
+ */ +public class RestExecuteAnomalyDetectorAction extends BaseRestHandler { + + public static final String DETECT_DATA_ACTION = "execute_anomaly_detector"; + // TODO: apply timeout config + private volatile TimeValue requestTimeout; + + private final Logger logger = LogManager.getLogger(RestExecuteAnomalyDetectorAction.class); + + public RestExecuteAnomalyDetectorAction(Settings settings, ClusterService clusterService) { + this.requestTimeout = REQUEST_TIMEOUT.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(REQUEST_TIMEOUT, it -> requestTimeout = it); + } + + @Override + public String getName() { + return DETECT_DATA_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ADEnabledSetting.isADEnabled()) { + throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); + } + AnomalyDetectorExecutionInput input = getAnomalyDetectorExecutionInput(request); + return channel -> { + String error = validateAdExecutionInput(input); + if (StringUtils.isNotBlank(error)) { + channel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, error)); + return; + } + + AnomalyResultRequest getRequest = new AnomalyResultRequest( + input.getDetectorId(), + input.getPeriodStart().toEpochMilli(), + input.getPeriodEnd().toEpochMilli() + ); + client.execute(AnomalyResultAction.INSTANCE, getRequest, new RestToXContentListener<>(channel)); + }; + } + + private AnomalyDetectorExecutionInput getAnomalyDetectorExecutionInput(RestRequest request) throws IOException { + String detectorId = null; + if (request.hasParam(DETECTOR_ID)) { + detectorId = request.param(DETECTOR_ID); + } + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetectorExecutionInput input = AnomalyDetectorExecutionInput.parse(parser, detectorId); + if (detectorId != null) { + input.setDetectorId(detectorId); + } + return input; + } + + private String validateAdExecutionInput(AnomalyDetectorExecutionInput input) { + if (StringUtils.isBlank(input.getDetectorId()) && input.getDetector() == null) { + return "Must set anomaly detector id or detector"; + } + if (input.getPeriodStart() == null || input.getPeriodEnd() == null) { + return "Must set both period start and end date with epoch of milliseconds"; + } + if (!input.getPeriodStart().isBefore(input.getPeriodEnd())) { + return "Period start date should be before end date"; + } + return null; + } + + @Override + public List routes() { + return ImmutableList.of(); + } + + @Override + public List replacedRoutes() { + return ImmutableList + .of( + // get AD result, for regular run + new ReplacedRoute( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, RUN), + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, RUN) + ) + ); + } +} diff --git a/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java index c7864cca4..d14ff85ce 100644 --- a/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java @@ -22,7 +22,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.AnomalyDetectorPlugin; 
import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.settings.ADEnabledSetting; @@ -34,6 +33,7 @@ import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestActions; import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; @@ -88,40 +88,49 @@ public List routes() { // Opensearch-only API. Considering users may provide entity in the search body, support POST as well. new Route( RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE) + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE) ), new Route( RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/{%s}/%s/{%s}", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE, TYPE) + String + .format(Locale.ROOT, "%s/{%s}/%s/{%s}", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE, TYPE) ) ); } @Override public List replacedRoutes() { - String path = String.format(Locale.ROOT, "%s/{%s}", AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID); - String newPath = String.format(Locale.ROOT, "%s/{%s}", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID); + String path = String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID); + String newPath = String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID); return ImmutableList .of( new ReplacedRoute(RestRequest.Method.GET, newPath, RestRequest.Method.GET, path), new ReplacedRoute(RestRequest.Method.HEAD, newPath, RestRequest.Method.HEAD, path), new ReplacedRoute( RestRequest.Method.GET, - String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE), + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE), RestRequest.Method.GET, - String.format(Locale.ROOT, "%s/{%s}/%s", AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, PROFILE) + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, PROFILE) ), // types is a profile names. See a complete list of supported profiles names in // org.opensearch.ad.model.ProfileName. 
new ReplacedRoute( RestRequest.Method.GET, - String.format(Locale.ROOT, "%s/{%s}/%s/{%s}", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE, TYPE), + String + .format( + Locale.ROOT, + "%s/{%s}/%s/{%s}", + TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, + DETECTOR_ID, + PROFILE, + TYPE + ), RestRequest.Method.GET, String .format( Locale.ROOT, "%s/{%s}/%s/{%s}", - AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, + TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, PROFILE, TYPE diff --git a/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java-e b/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java-e new file mode 100644 index 000000000..d14ff85ce --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java-e @@ -0,0 +1,172 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest; + +import static org.opensearch.timeseries.util.RestHandlerUtils.DETECTOR_ID; +import static org.opensearch.timeseries.util.RestHandlerUtils.PROFILE; +import static org.opensearch.timeseries.util.RestHandlerUtils.TYPE; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.transport.GetAnomalyDetectorAction; +import org.opensearch.ad.transport.GetAnomalyDetectorRequest; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.Strings; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestActions; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to retrieve an anomaly detector. 
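
In the handler completed just below, buildEntity turns either request parameters or a JSON body into an Entity for entity-level profile requests. A minimal sketch of both paths, using only the factory methods visible in this patch (Entity.createSingleAttributeEntity and Entity.fromJsonObject); the category field "clientip" and value "13.24.0.0" are the illustrative values from the handler's own comment.

import java.io.IOException;
import java.util.Optional;

import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.timeseries.model.Entity;

public final class ProfileEntitySketch {
    private ProfileEntitySketch() {}

    // Single-stream style: one categorical field and one value (illustrative values).
    static Entity singleAttribute() {
        return Entity.createSingleAttributeEntity("clientip", "13.24.0.0");
    }

    // HCAD style: the entity comes from a JSON body, parsed the same way buildEntity does below.
    static Entity fromBody(XContentParser parser) throws IOException {
        Optional<Entity> entity = Entity.fromJsonObject(parser);
        return entity.orElse(null);
    }
}
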
+ */ +public class RestGetAnomalyDetectorAction extends BaseRestHandler { + + private static final String GET_ANOMALY_DETECTOR_ACTION = "get_anomaly_detector"; + private static final Logger logger = LogManager.getLogger(RestGetAnomalyDetectorAction.class); + + public RestGetAnomalyDetectorAction() {} + + @Override + public String getName() { + return GET_ANOMALY_DETECTOR_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ADEnabledSetting.isADEnabled()) { + throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); + } + String detectorId = request.param(DETECTOR_ID); + String typesStr = request.param(TYPE); + + String rawPath = request.rawPath(); + boolean returnJob = request.paramAsBoolean("job", false); + boolean returnTask = request.paramAsBoolean("task", false); + boolean all = request.paramAsBoolean("_all", false); + GetAnomalyDetectorRequest getAnomalyDetectorRequest = new GetAnomalyDetectorRequest( + detectorId, + RestActions.parseVersion(request), + returnJob, + returnTask, + typesStr, + rawPath, + all, + buildEntity(request, detectorId) + ); + + return channel -> client + .execute(GetAnomalyDetectorAction.INSTANCE, getAnomalyDetectorRequest, new RestToXContentListener<>(channel)); + } + + @Override + public List routes() { + return ImmutableList + .of( + // Opensearch-only API. Considering users may provide entity in the search body, support POST as well. + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE) + ), + new Route( + RestRequest.Method.POST, + String + .format(Locale.ROOT, "%s/{%s}/%s/{%s}", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE, TYPE) + ) + ); + } + + @Override + public List replacedRoutes() { + String path = String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID); + String newPath = String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID); + return ImmutableList + .of( + new ReplacedRoute(RestRequest.Method.GET, newPath, RestRequest.Method.GET, path), + new ReplacedRoute(RestRequest.Method.HEAD, newPath, RestRequest.Method.HEAD, path), + new ReplacedRoute( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID, PROFILE), + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID, PROFILE) + ), + // types is a profile names. See a complete list of supported profiles names in + // org.opensearch.ad.model.ProfileName. 
+ new ReplacedRoute( + RestRequest.Method.GET, + String + .format( + Locale.ROOT, + "%s/{%s}/%s/{%s}", + TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, + DETECTOR_ID, + PROFILE, + TYPE + ), + RestRequest.Method.GET, + String + .format( + Locale.ROOT, + "%s/{%s}/%s/{%s}", + TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, + DETECTOR_ID, + PROFILE, + TYPE + ) + ) + ); + } + + private Entity buildEntity(RestRequest request, String detectorId) throws IOException { + if (Strings.isEmpty(detectorId)) { + throw new IllegalStateException(ADCommonMessages.AD_ID_MISSING_MSG); + } + + String entityName = request.param(ADCommonName.CATEGORICAL_FIELD); + String entityValue = request.param(CommonName.ENTITY_KEY); + + if (entityName != null && entityValue != null) { + // single-stream profile request: + // GET _plugins/_anomaly_detection/detectors//_profile/init_progress?category_field=&entity= + return Entity.createSingleAttributeEntity(entityName, entityValue); + } else if (request.hasContent()) { + /* HCAD profile request: + * GET _plugins/_anomaly_detection/detectors//_profile/init_progress + * { + * "entity": [{ + * "name": "clientip", + * "value": "13.24.0.0" + * }] + * } + */ + Optional entity = Entity.fromJsonObject(request.contentParser()); + if (entity.isPresent()) { + return entity.get(); + } + } + // not a valid profile request with correct entity information + return null; + } +} diff --git a/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java index a921d09e1..6231d8e11 100644 --- a/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java @@ -11,7 +11,7 @@ package org.opensearch.ad.rest; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.util.RestHandlerUtils.DETECTOR_ID; import static org.opensearch.timeseries.util.RestHandlerUtils.IF_PRIMARY_TERM; import static org.opensearch.timeseries.util.RestHandlerUtils.IF_SEQ_NO; @@ -24,7 +24,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.support.WriteRequest; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.ADEnabledSetting; @@ -34,6 +33,7 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.seqno.SequenceNumbers; @@ -41,8 +41,8 @@ import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestResponse; -import org.opensearch.rest.RestStatus; import org.opensearch.rest.action.RestResponseListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import com.google.common.collect.ImmutableList; @@ -113,16 +113,16 @@ public List replacedRoutes() { // Create new ReplacedRoute( RestRequest.Method.POST, - AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, + TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, RestRequest.Method.POST, - 
AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI ), // Update new ReplacedRoute( RestRequest.Method.PUT, - String.format(Locale.ROOT, "%s/{%s}", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID), + String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID), RestRequest.Method.PUT, - String.format(Locale.ROOT, "%s/{%s}", AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID) + String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID) ) ); } @@ -143,7 +143,7 @@ public RestResponse buildResponse(IndexAnomalyDetectorResponse response) throws response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS) ); if (restStatus == RestStatus.CREATED) { - String location = String.format(Locale.ROOT, "%s/%s", AnomalyDetectorPlugin.LEGACY_AD_BASE, response.getId()); + String location = String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE, response.getId()); bytesRestResponse.addHeader("Location", location); } return bytesRestResponse; diff --git a/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java-e b/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java-e new file mode 100644 index 000000000..9bb62d411 --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java-e @@ -0,0 +1,153 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.RestHandlerUtils.DETECTOR_ID; +import static org.opensearch.timeseries.util.RestHandlerUtils.IF_PRIMARY_TERM; +import static org.opensearch.timeseries.util.RestHandlerUtils.IF_SEQ_NO; +import static org.opensearch.timeseries.util.RestHandlerUtils.REFRESH; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.transport.IndexAnomalyDetectorAction; +import org.opensearch.ad.transport.IndexAnomalyDetectorRequest; +import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.rest.action.RestResponseListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +import com.google.common.collect.ImmutableList; + +/** + * Rest handlers to create and update anomaly detector. 
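
The create/update handler completed just below answers POST (create) with 201 Created plus a Location header and PUT (update) with 200 OK. The sketch that follows isolates that status and header decision; it only relies on TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE and BytesRestResponse.addHeader, both of which appear in this patch, and is a sketch rather than the handler's actual listener.

import java.util.Locale;

import org.opensearch.core.rest.RestStatus;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;

public final class IndexResponseStatusSketch {
    private IndexResponseStatusSketch() {}

    // PUT (update) maps to 200 OK; anything else here is the create path, 201 Created.
    static RestStatus statusFor(RestRequest.Method method) {
        return method == RestRequest.Method.PUT ? RestStatus.OK : RestStatus.CREATED;
    }

    // On create, advertise the new resource under the legacy AD base path, as the handler below does.
    static void addLocationIfCreated(BytesRestResponse response, RestStatus status, String detectorId) {
        if (status == RestStatus.CREATED) {
            response.addHeader("Location", String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE, detectorId));
        }
    }
}
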
+ */ +public class RestIndexAnomalyDetectorAction extends AbstractAnomalyDetectorAction { + + private static final String INDEX_ANOMALY_DETECTOR_ACTION = "index_anomaly_detector_action"; + private final Logger logger = LogManager.getLogger(RestIndexAnomalyDetectorAction.class); + + public RestIndexAnomalyDetectorAction(Settings settings, ClusterService clusterService) { + super(settings, clusterService); + } + + @Override + public String getName() { + return INDEX_ANOMALY_DETECTOR_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ADEnabledSetting.isADEnabled()) { + throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); + } + + String detectorId = request.param(DETECTOR_ID, AnomalyDetector.NO_ID); + logger.info("AnomalyDetector {} action for detectorId {}", request.method(), detectorId); + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + // TODO: check detection interval < modelTTL + AnomalyDetector detector = AnomalyDetector.parse(parser, detectorId, null, detectionInterval, detectionWindowDelay); + + long seqNo = request.paramAsLong(IF_SEQ_NO, SequenceNumbers.UNASSIGNED_SEQ_NO); + long primaryTerm = request.paramAsLong(IF_PRIMARY_TERM, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); + WriteRequest.RefreshPolicy refreshPolicy = request.hasParam(REFRESH) + ? WriteRequest.RefreshPolicy.parse(request.param(REFRESH)) + : WriteRequest.RefreshPolicy.IMMEDIATE; + RestRequest.Method method = request.getHttpRequest().method(); + + IndexAnomalyDetectorRequest indexAnomalyDetectorRequest = new IndexAnomalyDetectorRequest( + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + detector, + method, + requestTimeout, + maxSingleEntityDetectors, + maxMultiEntityDetectors, + maxAnomalyFeatures + ); + + return channel -> client + .execute(IndexAnomalyDetectorAction.INSTANCE, indexAnomalyDetectorRequest, indexAnomalyDetectorResponse(channel, method)); + } + + @Override + public List routes() { + return ImmutableList.of(); + } + + @Override + public List replacedRoutes() { + return ImmutableList + .of( + // Create + new ReplacedRoute( + RestRequest.Method.POST, + TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, + RestRequest.Method.POST, + TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + ), + // Update + new ReplacedRoute( + RestRequest.Method.PUT, + String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, DETECTOR_ID), + RestRequest.Method.PUT, + String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, DETECTOR_ID) + ) + ); + } + + private RestResponseListener indexAnomalyDetectorResponse( + RestChannel channel, + RestRequest.Method method + ) { + return new RestResponseListener(channel) { + @Override + public RestResponse buildResponse(IndexAnomalyDetectorResponse response) throws Exception { + RestStatus restStatus = RestStatus.CREATED; + if (method == RestRequest.Method.PUT) { + restStatus = RestStatus.OK; + } + BytesRestResponse bytesRestResponse = new BytesRestResponse( + restStatus, + response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS) + ); + if (restStatus == RestStatus.CREATED) { + String location = String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE, response.getId()); + bytesRestResponse.addHeader("Location", location); + } + return bytesRestResponse; + } + }; + } +} diff --git 
a/src/main/java/org/opensearch/ad/rest/RestPreviewAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestPreviewAnomalyDetectorAction.java index 6af609dfb..9c11cc1cc 100644 --- a/src/main/java/org/opensearch/ad/rest/RestPreviewAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestPreviewAnomalyDetectorAction.java @@ -11,7 +11,7 @@ package org.opensearch.ad.rest; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.util.RestHandlerUtils.PREVIEW; import java.io.IOException; @@ -21,20 +21,20 @@ import org.apache.commons.lang.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.model.AnomalyDetectorExecutionInput; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.PreviewAnomalyDetectorAction; import org.opensearch.ad.transport.PreviewAnomalyDetectorRequest; import org.opensearch.core.common.Strings; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; -import org.opensearch.rest.RestStatus; import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.util.RestHandlerUtils; import com.google.common.collect.ImmutableList; @@ -109,7 +109,7 @@ public List routes() { // preview detector new Route( RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/%s", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, PREVIEW) + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, PREVIEW) ) ); } @@ -125,7 +125,7 @@ public List replacedRoutes() { .format( Locale.ROOT, "%s/{%s}/%s", - AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, + TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, RestHandlerUtils.DETECTOR_ID, PREVIEW ), @@ -134,7 +134,7 @@ public List replacedRoutes() { .format( Locale.ROOT, "%s/{%s}/%s", - AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, + TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, RestHandlerUtils.DETECTOR_ID, PREVIEW ) diff --git a/src/main/java/org/opensearch/ad/rest/RestPreviewAnomalyDetectorAction.java-e b/src/main/java/org/opensearch/ad/rest/RestPreviewAnomalyDetectorAction.java-e new file mode 100644 index 000000000..6d53569e0 --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/RestPreviewAnomalyDetectorAction.java-e @@ -0,0 +1,144 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.RestHandlerUtils.PREVIEW; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.apache.commons.lang.StringUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.model.AnomalyDetectorExecutionInput; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.transport.PreviewAnomalyDetectorAction; +import org.opensearch.ad.transport.PreviewAnomalyDetectorRequest; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.util.RestHandlerUtils; + +import com.google.common.collect.ImmutableList; + +public class RestPreviewAnomalyDetectorAction extends BaseRestHandler { + + public static final String PREVIEW_ANOMALY_DETECTOR_ACTION = "preview_anomaly_detector"; + + private static final Logger logger = LogManager.getLogger(RestPreviewAnomalyDetectorAction.class); + + public RestPreviewAnomalyDetectorAction() {} + + @Override + public String getName() { + return PREVIEW_ANOMALY_DETECTOR_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, org.opensearch.client.node.NodeClient client) throws IOException { + if (!ADEnabledSetting.isADEnabled()) { + throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); + } + + AnomalyDetectorExecutionInput input = getAnomalyDetectorExecutionInput(request); + + return channel -> { + String rawPath = request.rawPath(); + String error = validateAdExecutionInput(input); + if (StringUtils.isNotBlank(error)) { + channel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, error)); + return; + } + PreviewAnomalyDetectorRequest previewRequest = new PreviewAnomalyDetectorRequest( + input.getDetector(), + input.getDetectorId(), + input.getPeriodStart(), + input.getPeriodEnd() + ); + client.execute(PreviewAnomalyDetectorAction.INSTANCE, previewRequest, new RestToXContentListener<>(channel)); + }; + } + + private AnomalyDetectorExecutionInput getAnomalyDetectorExecutionInput(RestRequest request) throws IOException { + String detectorId = null; + if (request.hasParam(RestHandlerUtils.DETECTOR_ID)) { + detectorId = request.param(RestHandlerUtils.DETECTOR_ID); + } + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetectorExecutionInput input = AnomalyDetectorExecutionInput.parse(parser, detectorId); + return input; + } + + private String validateAdExecutionInput(AnomalyDetectorExecutionInput input) { + if (input.getPeriodStart() == null || input.getPeriodEnd() == null) { + return "Must set both period start and end date with epoch of milliseconds"; + } + if (!input.getPeriodStart().isBefore(input.getPeriodEnd())) { + return "Period start date should be before end date"; + } + if (Strings.isEmpty(input.getDetectorId()) && input.getDetector() == null) { + return "Must set 
detector id or detector"; + } + return null; + } + + @Override + public List routes() { + return ImmutableList + .of( + // preview detector + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, PREVIEW) + ) + ); + } + + @Override + public List replacedRoutes() { + return ImmutableList + .of( + // Preview Detector + new ReplacedRoute( + RestRequest.Method.POST, + String + .format( + Locale.ROOT, + "%s/{%s}/%s", + TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, + RestHandlerUtils.DETECTOR_ID, + PREVIEW + ), + RestRequest.Method.POST, + String + .format( + Locale.ROOT, + "%s/{%s}/%s", + TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, + RestHandlerUtils.DETECTOR_ID, + PREVIEW + ) + ) + ); + } +} diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchADTasksAction.java b/src/main/java/org/opensearch/ad/rest/RestSearchADTasksAction.java index 08e627d19..6a1bfce58 100644 --- a/src/main/java/org/opensearch/ad/rest/RestSearchADTasksAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestSearchADTasksAction.java @@ -12,10 +12,10 @@ package org.opensearch.ad.rest; import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.transport.SearchADTasksAction; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import com.google.common.collect.ImmutableList; @@ -24,8 +24,8 @@ */ public class RestSearchADTasksAction extends AbstractSearchAction { - private static final String LEGACY_URL_PATH = AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + "/tasks/_search"; - private static final String URL_PATH = AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI + "/tasks/_search"; + private static final String LEGACY_URL_PATH = TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + "/tasks/_search"; + private static final String URL_PATH = TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI + "/tasks/_search"; private final String SEARCH_ANOMALY_DETECTION_TASKS = "search_anomaly_detection_tasks"; public RestSearchADTasksAction() { diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchADTasksAction.java-e b/src/main/java/org/opensearch/ad/rest/RestSearchADTasksAction.java-e new file mode 100644 index 000000000..6a1bfce58 --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/RestSearchADTasksAction.java-e @@ -0,0 +1,46 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest; + +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.transport.SearchADTasksAction; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to search AD tasks. 
+ */ +public class RestSearchADTasksAction extends AbstractSearchAction { + + private static final String LEGACY_URL_PATH = TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + "/tasks/_search"; + private static final String URL_PATH = TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI + "/tasks/_search"; + private final String SEARCH_ANOMALY_DETECTION_TASKS = "search_anomaly_detection_tasks"; + + public RestSearchADTasksAction() { + super( + ImmutableList.of(), + ImmutableList.of(Pair.of(URL_PATH, LEGACY_URL_PATH)), + ADCommonName.DETECTION_STATE_INDEX, + ADTask.class, + SearchADTasksAction.INSTANCE + ); + } + + @Override + public String getName() { + return SEARCH_ANOMALY_DETECTION_TASKS; + } + +} diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorAction.java index 1e406c4ab..214fa8b2c 100644 --- a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorAction.java @@ -12,9 +12,9 @@ package org.opensearch.ad.rest; import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.transport.SearchAnomalyDetectorAction; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.constant.CommonName; import com.google.common.collect.ImmutableList; @@ -24,8 +24,8 @@ */ public class RestSearchAnomalyDetectorAction extends AbstractSearchAction { - private static final String LEGACY_URL_PATH = AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + "/_search"; - private static final String URL_PATH = AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI + "/_search"; + private static final String LEGACY_URL_PATH = TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + "/_search"; + private static final String URL_PATH = TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI + "/_search"; private final String SEARCH_ANOMALY_DETECTOR_ACTION = "search_anomaly_detector"; public RestSearchAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorAction.java-e b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorAction.java-e new file mode 100644 index 000000000..214fa8b2c --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorAction.java-e @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest; + +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.transport.SearchAnomalyDetectorAction; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.constant.CommonName; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to search anomaly detectors. 
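
The detector search handler completed just below only wires URL paths, the config index, the model class, and the transport action into AbstractSearchAction. As a minimal sketch of what that base class ultimately sends, here is a SearchRequest assembled the same way against the config index; leaving the source without an explicit query behaves as match-all and is a placeholder.

import org.opensearch.action.search.SearchRequest;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.timeseries.constant.CommonName;

public final class DetectorSearchSketch {
    private DetectorSearchSketch() {}

    // Mirrors AbstractSearchAction.prepareRequest: sequence numbers and versions are requested,
    // and the search targets the detector config index; no explicit query means match-all.
    static SearchRequest matchAll() {
        SearchSourceBuilder source = new SearchSourceBuilder();
        source.seqNoAndPrimaryTerm(true).version(true);
        return new SearchRequest().source(source).indices(CommonName.CONFIG_INDEX);
    }
}
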
+ */ +public class RestSearchAnomalyDetectorAction extends AbstractSearchAction { + + private static final String LEGACY_URL_PATH = TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + "/_search"; + private static final String URL_PATH = TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI + "/_search"; + private final String SEARCH_ANOMALY_DETECTOR_ACTION = "search_anomaly_detector"; + + public RestSearchAnomalyDetectorAction() { + super( + ImmutableList.of(), + ImmutableList.of(Pair.of(URL_PATH, LEGACY_URL_PATH)), + CommonName.CONFIG_INDEX, + AnomalyDetector.class, + SearchAnomalyDetectorAction.INSTANCE + ); + } + + @Override + public String getName() { + return SEARCH_ANOMALY_DETECTOR_ACTION; + } +} diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorInfoAction.java b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorInfoAction.java index 4c7857287..1f2ade113 100644 --- a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorInfoAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorInfoAction.java @@ -20,7 +20,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.SearchAnomalyDetectorInfoAction; @@ -29,6 +28,7 @@ import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import com.google.common.collect.ImmutableList; @@ -71,16 +71,16 @@ public List replacedRoutes() { // get the count of number of detectors new ReplacedRoute( RestRequest.Method.GET, - String.format(Locale.ROOT, "%s/%s", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, COUNT), + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, COUNT), RestRequest.Method.GET, - String.format(Locale.ROOT, "%s/%s", AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, COUNT) + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, COUNT) ), // get if a detector name exists with name new ReplacedRoute( RestRequest.Method.GET, - String.format(Locale.ROOT, "%s/%s", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, MATCH), + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, MATCH), RestRequest.Method.GET, - String.format(Locale.ROOT, "%s/%s", AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, MATCH) + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, MATCH) ) ); } diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorInfoAction.java-e b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorInfoAction.java-e new file mode 100644 index 000000000..1f2ade113 --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyDetectorInfoAction.java-e @@ -0,0 +1,87 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.rest; + +import static org.opensearch.timeseries.util.RestHandlerUtils.COUNT; +import static org.opensearch.timeseries.util.RestHandlerUtils.MATCH; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.transport.SearchAnomalyDetectorInfoAction; +import org.opensearch.ad.transport.SearchAnomalyDetectorInfoRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +import com.google.common.collect.ImmutableList; + +public class RestSearchAnomalyDetectorInfoAction extends BaseRestHandler { + + public static final String SEARCH_ANOMALY_DETECTOR_INFO_ACTION = "search_anomaly_detector_info"; + + private static final Logger logger = LogManager.getLogger(RestSearchAnomalyDetectorInfoAction.class); + + public RestSearchAnomalyDetectorInfoAction() {} + + @Override + public String getName() { + return SEARCH_ANOMALY_DETECTOR_INFO_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, org.opensearch.client.node.NodeClient client) throws IOException { + if (!ADEnabledSetting.isADEnabled()) { + throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); + } + + String detectorName = request.param("name", null); + String rawPath = request.rawPath(); + + SearchAnomalyDetectorInfoRequest searchAnomalyDetectorInfoRequest = new SearchAnomalyDetectorInfoRequest(detectorName, rawPath); + return channel -> client + .execute(SearchAnomalyDetectorInfoAction.INSTANCE, searchAnomalyDetectorInfoRequest, new RestToXContentListener<>(channel)); + } + + @Override + public List routes() { + return ImmutableList.of(); + } + + @Override + public List replacedRoutes() { + return ImmutableList + .of( + // get the count of number of detectors + new ReplacedRoute( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, COUNT), + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, COUNT) + ), + // get if a detector name exists with name + new ReplacedRoute( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, MATCH), + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI, MATCH) + ) + ); + } +} diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java index e1c299060..9db521595 100644 --- a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java @@ -21,7 +21,6 @@ import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.util.Strings; import org.opensearch.action.search.SearchRequest; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.settings.ADEnabledSetting; @@ -29,6 +28,7 @@ import org.opensearch.client.node.NodeClient; import 
org.opensearch.rest.RestRequest; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import com.google.common.collect.ImmutableList; @@ -36,8 +36,8 @@ * This class consists of the REST handler to search anomaly results. */ public class RestSearchAnomalyResultAction extends AbstractSearchAction { - private static final String LEGACY_URL_PATH = AnomalyDetectorPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + "/results/_search"; - private static final String URL_PATH = AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI + "/results/_search"; + private static final String LEGACY_URL_PATH = TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + "/results/_search"; + private static final String URL_PATH = TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI + "/results/_search"; public static final String SEARCH_ANOMALY_RESULT_ACTION = "search_anomaly_result"; public RestSearchAnomalyResultAction() { diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java-e b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java-e new file mode 100644 index 000000000..9db521595 --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/RestSearchAnomalyResultAction.java-e @@ -0,0 +1,87 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest; + +import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN; +import static org.opensearch.timeseries.util.RestHandlerUtils.RESULT_INDEX; +import static org.opensearch.timeseries.util.RestHandlerUtils.getSourceContext; + +import java.io.IOException; +import java.util.Locale; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.util.Strings; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.transport.SearchAnomalyResultAction; +import org.opensearch.client.node.NodeClient; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to search anomaly results. 
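The result-search handler that follows picks its target indices at request time: by default it searches the built-in AD results pattern, a custom result index supplied in the path is searched alongside it, and the only_query_custom_result_index flag restricts the search to the custom index alone (failing if none was given). A small standalone sketch of that selection rule; the helper and the pattern literal are assumptions for illustration, not the plugin's code:

import java.util.List;

// Hypothetical helper reproducing the index-selection rule described above; not part of the plugin.
public final class ResultIndexSelectionSketch {

    // Assumed value for illustration; the plugin defines its own result index pattern constant.
    static final String ALL_AD_RESULTS_INDEX_PATTERN = ".opendistro-anomaly-results*";

    static List<String> indicesToSearch(String customResultIndex, boolean onlyQueryCustomResultIndex) {
        if (customResultIndex == null) {
            if (onlyQueryCustomResultIndex) {
                throw new IllegalStateException("No custom result index set.");
            }
            return List.of(ALL_AD_RESULTS_INDEX_PATTERN);
        }
        return onlyQueryCustomResultIndex
            ? List.of(customResultIndex)
            : List.of(ALL_AD_RESULTS_INDEX_PATTERN, customResultIndex);
    }

    public static void main(String[] args) {
        System.out.println(indicesToSearch(null, false));                 // default pattern only
        System.out.println(indicesToSearch("custom-ad-results", false));  // pattern plus custom index
        System.out.println(indicesToSearch("custom-ad-results", true));   // custom index only
    }
}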
+ */ +public class RestSearchAnomalyResultAction extends AbstractSearchAction { + private static final String LEGACY_URL_PATH = TimeSeriesAnalyticsPlugin.LEGACY_OPENDISTRO_AD_BASE_URI + "/results/_search"; + private static final String URL_PATH = TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI + "/results/_search"; + public static final String SEARCH_ANOMALY_RESULT_ACTION = "search_anomaly_result"; + + public RestSearchAnomalyResultAction() { + super( + ImmutableList.of(String.format(Locale.ROOT, "%s/{%s}", URL_PATH, RESULT_INDEX)), + ImmutableList.of(Pair.of(URL_PATH, LEGACY_URL_PATH)), + ALL_AD_RESULTS_INDEX_PATTERN, + AnomalyResult.class, + SearchAnomalyResultAction.INSTANCE + ); + } + + @Override + public String getName() { + return SEARCH_ANOMALY_RESULT_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ADEnabledSetting.isADEnabled()) { + throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); + } + + // resultIndex could be concrete index or index pattern + String resultIndex = Strings.trimToNull(request.param(RESULT_INDEX)); + boolean onlyQueryCustomResultIndex = request.paramAsBoolean("only_query_custom_result_index", false); + if (resultIndex == null && onlyQueryCustomResultIndex) { + throw new IllegalStateException("No custom result index set."); + } + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); + searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder)); + searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); + SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(this.index); + + if (resultIndex != null) { + if (onlyQueryCustomResultIndex) { + searchRequest.indices(resultIndex); + } else { + searchRequest.indices(this.index, resultIndex); + } + } + return channel -> client.execute(actionType, searchRequest, search(channel)); + } + +} diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchTopAnomalyResultAction.java b/src/main/java/org/opensearch/ad/rest/RestSearchTopAnomalyResultAction.java index 1bb3e5340..45b8da8a0 100644 --- a/src/main/java/org/opensearch/ad/rest/RestSearchTopAnomalyResultAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestSearchTopAnomalyResultAction.java @@ -11,13 +11,12 @@ package org.opensearch.ad.rest; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.util.List; import java.util.Locale; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.SearchTopAnomalyResultAction; @@ -27,6 +26,7 @@ import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.util.RestHandlerUtils; import com.google.common.collect.ImmutableList; @@ -40,7 +40,7 @@ public class RestSearchTopAnomalyResultAction extends BaseRestHandler { .format( Locale.ROOT, "%s/{%s}/%s/%s", - AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, + TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, RestHandlerUtils.DETECTOR_ID, RestHandlerUtils.RESULTS, 
RestHandlerUtils.TOP_ANOMALIES diff --git a/src/main/java/org/opensearch/ad/rest/RestSearchTopAnomalyResultAction.java-e b/src/main/java/org/opensearch/ad/rest/RestSearchTopAnomalyResultAction.java-e new file mode 100644 index 000000000..45b8da8a0 --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/RestSearchTopAnomalyResultAction.java-e @@ -0,0 +1,90 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.transport.SearchTopAnomalyResultAction; +import org.opensearch.ad.transport.SearchTopAnomalyResultRequest; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.util.RestHandlerUtils; + +import com.google.common.collect.ImmutableList; + +/** + * The REST handler to search top entity anomaly results for HC detectors. + */ +public class RestSearchTopAnomalyResultAction extends BaseRestHandler { + + private static final String URL_PATH = String + .format( + Locale.ROOT, + "%s/{%s}/%s/%s", + TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, + RestHandlerUtils.DETECTOR_ID, + RestHandlerUtils.RESULTS, + RestHandlerUtils.TOP_ANOMALIES + ); + private final String SEARCH_TOP_ANOMALY_DETECTOR_ACTION = "search_top_anomaly_result"; + + public RestSearchTopAnomalyResultAction() {} + + @Override + public String getName() { + return SEARCH_TOP_ANOMALY_DETECTOR_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + + // Throw error if disabled + if (!ADEnabledSetting.isADEnabled()) { + throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); + } + + // Get the typed request + SearchTopAnomalyResultRequest searchTopAnomalyResultRequest = getSearchTopAnomalyResultRequest(request); + + return channel -> client + .execute(SearchTopAnomalyResultAction.INSTANCE, searchTopAnomalyResultRequest, new RestToXContentListener<>(channel)); + + } + + private SearchTopAnomalyResultRequest getSearchTopAnomalyResultRequest(RestRequest request) throws IOException { + String detectorId; + if (request.hasParam(RestHandlerUtils.DETECTOR_ID)) { + detectorId = request.param(RestHandlerUtils.DETECTOR_ID); + } else { + throw new IllegalStateException(ADCommonMessages.AD_ID_MISSING_MSG); + } + boolean historical = request.paramAsBoolean("historical", false); + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + return SearchTopAnomalyResultRequest.parse(parser, detectorId, historical); + } + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.POST, URL_PATH), new Route(RestRequest.Method.GET, URL_PATH)); + } +} diff --git 
a/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java index eb8598670..65b936e98 100644 --- a/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java @@ -11,8 +11,8 @@ package org.opensearch.ad.rest; -import static org.opensearch.ad.AnomalyDetectorPlugin.AD_BASE_URI; -import static org.opensearch.ad.AnomalyDetectorPlugin.LEGACY_AD_BASE; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BASE_URI; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE; import java.util.Arrays; import java.util.HashSet; diff --git a/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java-e b/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java-e new file mode 100644 index 000000000..65b936e98 --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java-e @@ -0,0 +1,159 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest; + +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BASE_URI; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.TreeSet; + +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.transport.ADStatsRequest; +import org.opensearch.ad.transport.StatsAnomalyDetectorAction; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.Strings; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +import com.google.common.collect.ImmutableList; + +/** + * RestStatsAnomalyDetectorAction consists of the REST handler to get the stats from the anomaly detector plugin. 
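The stats handler that follows resolves which statistics to return from the stat request parameter: nothing specified means every stat, a lone all-stats key likewise returns everything, mixing the all-stats key with individual names is rejected, and unknown names are reported back. A standalone sketch of those rules, with the "_all" literal and helper names assumed for illustration:

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Standalone sketch of the stat-selection rules described above; names and the "_all" literal are assumptions.
public final class StatSelectionSketch {

    static final String ALL_STATS_KEY = "_all";

    static Set<String> resolve(String statsParam, Set<String> validStats) {
        if (statsParam == null || statsParam.isEmpty()) {
            return validStats; // nothing requested: return every stat
        }
        Set<String> requested = new HashSet<>(Arrays.asList(statsParam.split(",")));
        if (requested.size() == 1 && requested.contains(ALL_STATS_KEY)) {
            return validStats; // a lone all-stats key also returns every stat
        }
        if (requested.contains(ALL_STATS_KEY)) {
            throw new IllegalArgumentException("Request contains " + ALL_STATS_KEY + " and individual stats");
        }
        for (String stat : requested) {
            if (!validStats.contains(stat)) {
                throw new IllegalArgumentException("Unrecognized stat: " + stat);
            }
        }
        return requested;
    }

    public static void main(String[] args) {
        Set<String> valid = Set.of("ad_execute_request_count", "ad_execute_fail_count");
        System.out.println(resolve(null, valid));
        System.out.println(resolve("ad_execute_fail_count", valid));
    }
}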
+ */ +public class RestStatsAnomalyDetectorAction extends BaseRestHandler { + + private static final String STATS_ANOMALY_DETECTOR_ACTION = "stats_anomaly_detector"; + private ADStats adStats; + private ClusterService clusterService; + private DiscoveryNodeFilterer nodeFilter; + + /** + * Constructor + * + * @param adStats ADStats object + * @param nodeFilter util class to get eligible data nodes + */ + public RestStatsAnomalyDetectorAction(ADStats adStats, DiscoveryNodeFilterer nodeFilter) { + this.adStats = adStats; + this.nodeFilter = nodeFilter; + } + + @Override + public String getName() { + return STATS_ANOMALY_DETECTOR_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + if (!ADEnabledSetting.isADEnabled()) { + throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); + } + ADStatsRequest adStatsRequest = getRequest(request); + return channel -> client.execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a ADStatsRequest from a RestRequest + * + * @param request RestRequest + * @return ADStatsRequest Request containing stats to be retrieved + */ + private ADStatsRequest getRequest(RestRequest request) { + // parse the nodes the user wants to query the stats for + String nodesIdsStr = request.param("nodeId"); + Set validStats = adStats.getStats().keySet(); + + ADStatsRequest adStatsRequest = null; + if (!Strings.isEmpty(nodesIdsStr)) { + String[] nodeIdsArr = nodesIdsStr.split(","); + adStatsRequest = new ADStatsRequest(nodeIdsArr); + } else { + DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); + adStatsRequest = new ADStatsRequest(dataNodes); + } + + adStatsRequest.timeout(request.param("timeout")); + + // parse the stats the user wants to see + HashSet statsSet = null; + String statsStr = request.param("stat"); + if (!Strings.isEmpty(statsStr)) { + statsSet = new HashSet<>(Arrays.asList(statsStr.split(","))); + } + + if (statsSet == null) { + adStatsRequest.addAll(validStats); // retrieve all stats if none are specified + } else if (statsSet.size() == 1 && statsSet.contains(ADStatsRequest.ALL_STATS_KEY)) { + adStatsRequest.addAll(validStats); + } else if (statsSet.contains(ADStatsRequest.ALL_STATS_KEY)) { + throw new IllegalArgumentException( + "Request " + request.path() + " contains " + ADStatsRequest.ALL_STATS_KEY + " and individual stats" + ); + } else { + Set invalidStats = new TreeSet<>(); + for (String stat : statsSet) { + if (validStats.contains(stat)) { + adStatsRequest.addStat(stat); + } else { + invalidStats.add(stat); + } + } + + if (!invalidStats.isEmpty()) { + throw new IllegalArgumentException(unrecognized(request, invalidStats, adStatsRequest.getStatsToBeRetrieved(), "stat")); + } + } + return adStatsRequest; + } + + @Override + public List routes() { + return ImmutableList.of(); + } + + @Override + public List replacedRoutes() { + return ImmutableList + .of( + // delete anomaly detector document + new ReplacedRoute( + RestRequest.Method.GET, + AD_BASE_URI + "/{nodeId}/stats/", + RestRequest.Method.GET, + LEGACY_AD_BASE + "/{nodeId}/stats/" + ), + new ReplacedRoute( + RestRequest.Method.GET, + AD_BASE_URI + "/{nodeId}/stats/{stat}", + RestRequest.Method.GET, + LEGACY_AD_BASE + "/{nodeId}/stats/{stat}" + ), + new ReplacedRoute(RestRequest.Method.GET, AD_BASE_URI + "/stats/", RestRequest.Method.GET, LEGACY_AD_BASE + "/stats/"), + new ReplacedRoute( + RestRequest.Method.GET, + AD_BASE_URI + "/stats/{stat}", + 
RestRequest.Method.GET, + LEGACY_AD_BASE + "/stats/{stat}" + ) + ); + } +} diff --git a/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java index 9f6808e18..e728889f8 100644 --- a/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java @@ -11,7 +11,7 @@ package org.opensearch.ad.rest; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.util.RestHandlerUtils.TYPE; import static org.opensearch.timeseries.util.RestHandlerUtils.VALIDATE; @@ -25,7 +25,6 @@ import java.util.stream.Collectors; import org.apache.commons.lang3.StringUtils; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorValidationIssue; @@ -36,13 +35,14 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; -import org.opensearch.rest.RestStatus; import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.ValidationException; import org.opensearch.timeseries.model.ValidationAspect; @@ -75,11 +75,11 @@ public List routes() { .of( new Route( RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/%s", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, VALIDATE) + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, VALIDATE) ), new Route( RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/%s/{%s}", AnomalyDetectorPlugin.AD_BASE_DETECTORS_URI, VALIDATE, TYPE) + String.format(Locale.ROOT, "%s/%s/{%s}", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, VALIDATE, TYPE) ) ); } diff --git a/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java-e b/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java-e new file mode 100644 index 000000000..9828ae9ee --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java-e @@ -0,0 +1,149 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.RestHandlerUtils.TYPE; +import static org.opensearch.timeseries.util.RestHandlerUtils.VALIDATE; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.DetectorValidationIssue; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.transport.ValidateAnomalyDetectorAction; +import org.opensearch.ad.transport.ValidateAnomalyDetectorRequest; +import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.model.ValidationAspect; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to validate anomaly detector configurations. + */ +public class RestValidateAnomalyDetectorAction extends AbstractAnomalyDetectorAction { + private static final String VALIDATE_ANOMALY_DETECTOR_ACTION = "validate_anomaly_detector_action"; + + public static final Set ALL_VALIDATION_ASPECTS_STRS = Arrays + .asList(ValidationAspect.values()) + .stream() + .map(aspect -> aspect.getName()) + .collect(Collectors.toSet()); + + public RestValidateAnomalyDetectorAction(Settings settings, ClusterService clusterService) { + super(settings, clusterService); + } + + @Override + public String getName() { + return VALIDATE_ANOMALY_DETECTOR_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, VALIDATE) + ), + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/%s/{%s}", TimeSeriesAnalyticsPlugin.AD_BASE_DETECTORS_URI, VALIDATE, TYPE) + ) + ); + } + + protected void sendAnomalyDetectorValidationParseResponse(DetectorValidationIssue issue, RestChannel channel) throws IOException { + try { + BytesRestResponse restResponse = new BytesRestResponse( + RestStatus.OK, + new ValidateAnomalyDetectorResponse(issue).toXContent(channel.newBuilder()) + ); + channel.sendResponse(restResponse); + } catch (Exception e) { + channel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, e.getMessage())); + } + } + + private Boolean validationTypesAreAccepted(String validationType) { + Set typesInRequest = new HashSet<>(Arrays.asList(validationType.split(","))); + return (!Collections.disjoint(typesInRequest, ALL_VALIDATION_ASPECTS_STRS)); + } + + @Override + protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, 
NodeClient client) throws IOException { + if (!ADEnabledSetting.isADEnabled()) { + throw new IllegalStateException(ADCommonMessages.DISABLED_ERR_MSG); + } + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + String typesStr = request.param(TYPE); + + // if type param isn't blank and isn't a part of possible validation types throws exception + if (!StringUtils.isBlank(typesStr)) { + if (!validationTypesAreAccepted(typesStr)) { + throw new IllegalStateException(ADCommonMessages.NOT_EXISTENT_VALIDATION_TYPE); + } + } + + return channel -> { + AnomalyDetector detector; + try { + detector = AnomalyDetector.parse(parser); + } catch (Exception ex) { + if (ex instanceof ValidationException) { + ValidationException ADException = (ValidationException) ex; + DetectorValidationIssue issue = new DetectorValidationIssue( + ADException.getAspect(), + ADException.getType(), + ADException.getMessage() + ); + sendAnomalyDetectorValidationParseResponse(issue, channel); + return; + } else { + throw ex; + } + } + ValidateAnomalyDetectorRequest validateAnomalyDetectorRequest = new ValidateAnomalyDetectorRequest( + detector, + typesStr, + maxSingleEntityDetectors, + maxMultiEntityDetectors, + maxAnomalyFeatures, + requestTimeout + ); + client.execute(ValidateAnomalyDetectorAction.INSTANCE, validateAnomalyDetectorRequest, new RestToXContentListener<>(channel)); + }; + } +} diff --git a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java index aa0b1e787..82f07b497 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java @@ -12,7 +12,7 @@ package org.opensearch.ad.rest.handler; import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; import static org.opensearch.timeseries.util.ParseUtils.listEqualsWithoutConsideringOrder; import static org.opensearch.timeseries.util.ParseUtils.parseAggregators; @@ -70,13 +70,13 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.authuser.User; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.rest.RestRequest; -import org.opensearch.rest.RestStatus; import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.timeseries.common.exception.ValidationException; @@ -584,10 +584,10 @@ protected void validateCategoricalField(String detectorId, boolean indexingDryRu ActionListener mappingsListener = ActionListener.wrap(getMappingsResponse -> { // example getMappingsResponse: // GetFieldMappingsResponse{mappings={server-metrics={_doc={service=FieldMappingMetadata{fullName='service', - // 
source=org.opensearch.common.bytes.BytesArray@7ba87dbd}}}}} + // source=org.opensearch.core.common.bytes.BytesArray@7ba87dbd}}}}} // for nested field, it would be // GetFieldMappingsResponse{mappings={server-metrics={_doc={host_nest.host2=FieldMappingMetadata{fullName='host_nest.host2', - // source=org.opensearch.common.bytes.BytesArray@8fb4de08}}}}} + // source=org.opensearch.core.common.bytes.BytesArray@8fb4de08}}}}} boolean foundField = false; // Review why the change from FieldMappingMetadata to GetFieldMappingsResponse.FieldMappingMetadata @@ -597,7 +597,7 @@ protected void validateCategoricalField(String detectorId, boolean indexingDryRu for (Map.Entry field2Metadata : mappingsByField.entrySet()) { // example output: // host_nest.host2=FieldMappingMetadata{fullName='host_nest.host2', - // source=org.opensearch.common.bytes.BytesArray@8fb4de08} + // source=org.opensearch.core.common.bytes.BytesArray@8fb4de08} // Review why the change from FieldMappingMetadata to GetFieldMappingsResponse.FieldMappingMetadata diff --git a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java-e b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java-e new file mode 100644 index 000000000..23ce7d2ca --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java-e @@ -0,0 +1,946 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest.handler; + +import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; +import static org.opensearch.timeseries.util.ParseUtils.listEqualsWithoutConsideringOrder; +import static org.opensearch.timeseries.util.ParseUtils.parseAggregators; +import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; +import static org.opensearch.timeseries.util.RestHandlerUtils.isExceptionCausedByInvalidQuery; + +import java.io.IOException; +import java.time.Clock; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.commons.lang.StringUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionResponse; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsAction; +import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsRequest; +import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import 
org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.support.replication.ReplicationResponse; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.MergeableList; +import org.opensearch.ad.rest.RestValidateAnomalyDetectorAction; +import org.opensearch.ad.settings.ADNumericSetting; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; +import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; +import org.opensearch.ad.util.MultiResponsesDelegateActionListener; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.rest.RestRequest; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.Sets; + +/** + * Abstract Anomaly detector REST action handler to process POST/PUT request. + * POST request is for either validating or creating anomaly detector. + * PUT request is for updating anomaly detector. + * + *

Create, Update and Validate APIs all share a similar validation process; the differences in logic
+ * between the three usages of this class are outlined below.
+ * + * • Create/Update: this class is extended by IndexAnomalyDetectorActionHandler, which handles the create and
+ * update AD REST actions. When this class is constructed from those actions, the isDryRun parameter will be
+ * instantiated as false. This means that if the AD index doesn't exist at the time the request is received, it will
+ * be created. Furthermore, this handler will actually create or update the AD, and it also handles a few exceptions
+ * as they are thrown instead of converting some of them to ADValidationExceptions.
+ * + * • Validate: this class is also extended by ValidateAnomalyDetectorActionHandler, which handles the validate AD
+ * REST actions. When this class is constructed from those actions, the isDryRun parameter will be instantiated as
+ * true. This means that if the AD index doesn't exist at the time the request is received, it won't be created.
+ * Furthermore, the AD won't actually be created, and all exceptions will be wrapped into DetectorValidationResponses,
+ * so the user is notified of which validation checks didn't pass.
+ * + * After completing the first round of validation, which is identical to the checks done for the create/update APIs,
+ * this code checks whether the validation type is 'model'; if so, it instantiates the ModelValidationActionHandler
+ * class and runs the non-blocking validation logic.
+ *
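As a compact illustration of the contract spelled out above, the sketch below models the two subclass roles with simplified stand-in names; the printed strings only summarize the behaviour described in this comment and are not the plugin's real control flow:

// Simplified structural sketch of the dry-run contract described above; names and messages are illustrative only.
abstract class BaseDetectorHandlerSketch {
    protected final boolean isDryRun;

    protected BaseDetectorHandlerSketch(boolean isDryRun) {
        this.isDryRun = isDryRun;
    }

    void process(boolean detectorValid) {
        if (!detectorValid) {
            if (isDryRun) {
                System.out.println("wrap the failure into a validation response for the caller");
            } else {
                throw new IllegalArgumentException("invalid detector configuration");
            }
            return;
        }
        System.out.println(isDryRun ? "report: no validation issues found" : "create or update the detector document");
    }
}

// Mirrors the create/update handler: isDryRun is false, so indices and documents are really written.
class IndexHandlerSketch extends BaseDetectorHandlerSketch {
    IndexHandlerSketch() {
        super(false);
    }
}

// Mirrors the validate handler: isDryRun is true, so nothing is written and failures become responses.
class ValidateHandlerSketch extends BaseDetectorHandlerSketch {
    ValidateHandlerSketch() {
        super(true);
    }
}

class DryRunSketchDemo {
    public static void main(String[] args) {
        new ValidateHandlerSketch().process(false); // reports the failed check instead of throwing
        new IndexHandlerSketch().process(true);     // actually creates or updates the detector
    }
}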
+ */ +public abstract class AbstractAnomalyDetectorActionHandler { + public static final String EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG = "Can't create more than %d multi-entity anomaly detectors."; + public static final String EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG = + "Can't create more than %d single-entity anomaly detectors."; + public static final String NO_DOCS_IN_USER_INDEX_MSG = "Can't create anomaly detector as no document is found in the indices: "; + public static final String ONLY_ONE_CATEGORICAL_FIELD_ERR_MSG = "We can have only one categorical field."; + public static final String CATEGORICAL_FIELD_TYPE_ERR_MSG = "A categorical field must be of type keyword or ip."; + public static final String CATEGORY_NOT_FOUND_ERR_MSG = "Can't find the categorical field %s"; + public static final String DUPLICATE_DETECTOR_MSG = "Cannot create anomaly detector with name [%s] as it's already used by detector %s"; + public static final String NAME_REGEX = "[a-zA-Z0-9._-]+"; + public static final Integer MAX_DETECTOR_NAME_SIZE = 64; + private static final Set DEFAULT_VALIDATION_ASPECTS = Sets.newHashSet(ValidationAspect.DETECTOR); + + public static String INVALID_NAME_SIZE = "Name should be shortened. The maximum limit is " + MAX_DETECTOR_NAME_SIZE + " characters."; + + protected final ADIndexManagement anomalyDetectionIndices; + protected final String detectorId; + protected final Long seqNo; + protected final Long primaryTerm; + protected final WriteRequest.RefreshPolicy refreshPolicy; + protected final AnomalyDetector anomalyDetector; + protected final ClusterService clusterService; + + protected final Logger logger = LogManager.getLogger(AbstractAnomalyDetectorActionHandler.class); + protected final TimeValue requestTimeout; + protected final Integer maxSingleEntityAnomalyDetectors; + protected final Integer maxMultiEntityAnomalyDetectors; + protected final Integer maxAnomalyFeatures; + protected final AnomalyDetectorActionHandler handler = new AnomalyDetectorActionHandler(); + protected final RestRequest.Method method; + protected final Client client; + protected final SecurityClientUtil clientUtil; + protected final TransportService transportService; + protected final NamedXContentRegistry xContentRegistry; + protected final ActionListener listener; + protected final User user; + protected final ADTaskManager adTaskManager; + protected final SearchFeatureDao searchFeatureDao; + protected final boolean isDryRun; + protected final Clock clock; + protected final String validationType; + protected final Settings settings; + + /** + * Constructor function. 
+ * + * @param clusterService ClusterService + * @param client ES node client that executes actions on the local node + * @param clientUtil AD security client + * @param transportService ES transport service + * @param listener ES channel used to construct bytes / builder based outputs, and send responses + * @param anomalyDetectionIndices anomaly detector index manager + * @param detectorId detector identifier + * @param seqNo sequence number of last modification + * @param primaryTerm primary term of last modification + * @param refreshPolicy refresh policy + * @param anomalyDetector anomaly detector instance + * @param requestTimeout request time out configuration + * @param maxSingleEntityAnomalyDetectors max single-entity anomaly detectors allowed + * @param maxMultiEntityAnomalyDetectors max multi-entity detectors allowed + * @param maxAnomalyFeatures max features allowed per detector + * @param method Rest Method type + * @param xContentRegistry Registry which is used for XContentParser + * @param user User context + * @param adTaskManager AD Task manager + * @param searchFeatureDao Search feature dao + * @param isDryRun Whether handler is dryrun or not + * @param validationType Whether validation is for detector or model + * @param clock clock object to know when to timeout + * @param settings Node settings + */ + public AbstractAnomalyDetectorActionHandler( + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + TransportService transportService, + ActionListener listener, + ADIndexManagement anomalyDetectionIndices, + String detectorId, + Long seqNo, + Long primaryTerm, + WriteRequest.RefreshPolicy refreshPolicy, + AnomalyDetector anomalyDetector, + TimeValue requestTimeout, + Integer maxSingleEntityAnomalyDetectors, + Integer maxMultiEntityAnomalyDetectors, + Integer maxAnomalyFeatures, + RestRequest.Method method, + NamedXContentRegistry xContentRegistry, + User user, + ADTaskManager adTaskManager, + SearchFeatureDao searchFeatureDao, + String validationType, + boolean isDryRun, + Clock clock, + Settings settings + ) { + this.clusterService = clusterService; + this.client = client; + this.clientUtil = clientUtil; + this.transportService = transportService; + this.anomalyDetectionIndices = anomalyDetectionIndices; + this.listener = listener; + this.detectorId = detectorId; + this.seqNo = seqNo; + this.primaryTerm = primaryTerm; + this.refreshPolicy = refreshPolicy; + this.anomalyDetector = anomalyDetector; + this.requestTimeout = requestTimeout; + this.maxSingleEntityAnomalyDetectors = maxSingleEntityAnomalyDetectors; + this.maxMultiEntityAnomalyDetectors = maxMultiEntityAnomalyDetectors; + this.maxAnomalyFeatures = maxAnomalyFeatures; + this.method = method; + this.xContentRegistry = xContentRegistry; + this.user = user; + this.adTaskManager = adTaskManager; + this.searchFeatureDao = searchFeatureDao; + this.validationType = validationType; + this.isDryRun = isDryRun; + this.clock = clock; + this.settings = settings; + } + + /** + * Start function to process create/update/validate anomaly detector request. + * If detector is not using custom result index, check if anomaly detector + * index exist first, if not, will create first. Otherwise, check if custom + * result index exists or not. If exists, will check if index mapping matches + * AD result index mapping and if user has correct permission to write index. + * If doesn't exist, will create custom result index with AD result index + * mapping. 
+ */ + public void start() { + String resultIndex = anomalyDetector.getCustomResultIndex(); + // use default detector result index which is system index + if (resultIndex == null) { + createOrUpdateDetector(); + return; + } + + if (this.isDryRun) { + if (anomalyDetectionIndices.doesIndexExist(resultIndex)) { + anomalyDetectionIndices + .validateCustomResultIndexAndExecute( + resultIndex, + () -> createOrUpdateDetector(), + ActionListener.wrap(r -> createOrUpdateDetector(), ex -> { + logger.error(ex); + listener + .onFailure( + new ValidationException(ex.getMessage(), ValidationIssueType.RESULT_INDEX, ValidationAspect.DETECTOR) + ); + return; + }) + ); + return; + } else { + createOrUpdateDetector(); + return; + } + } + // use custom result index if not validating and resultIndex not null + anomalyDetectionIndices.initCustomResultIndexAndExecute(resultIndex, () -> createOrUpdateDetector(), listener); + } + + // if isDryRun is true then this method is being executed through Validation API meaning actual + // index won't be created, only validation checks will be executed throughout the class + private void createOrUpdateDetector() { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + if (!anomalyDetectionIndices.doesConfigIndexExist() && !this.isDryRun) { + logger.info("AnomalyDetector Indices do not exist"); + anomalyDetectionIndices + .initConfigIndex( + ActionListener + .wrap(response -> onCreateMappingsResponse(response, false), exception -> listener.onFailure(exception)) + ); + } else { + logger.info("AnomalyDetector Indices do exist, calling prepareAnomalyDetectorIndexing"); + logger.info("DryRun variable " + this.isDryRun); + validateDetectorName(this.isDryRun); + } + } catch (Exception e) { + logger.error("Failed to create or update detector " + detectorId, e); + listener.onFailure(e); + } + } + + // These validation checks are executed here and not in AnomalyDetector.parse() + // in order to not break any past detectors that were made with invalid names + // because it was never check on the backend in the past + protected void validateDetectorName(boolean indexingDryRun) { + if (!anomalyDetector.getName().matches(NAME_REGEX)) { + listener.onFailure(new ValidationException(CommonMessages.INVALID_NAME, ValidationIssueType.NAME, ValidationAspect.DETECTOR)); + return; + + } + if (anomalyDetector.getName().length() > MAX_DETECTOR_NAME_SIZE) { + listener.onFailure(new ValidationException(INVALID_NAME_SIZE, ValidationIssueType.NAME, ValidationAspect.DETECTOR)); + return; + } + validateTimeField(indexingDryRun); + } + + protected void validateTimeField(boolean indexingDryRun) { + String givenTimeField = anomalyDetector.getTimeField(); + GetFieldMappingsRequest getMappingsRequest = new GetFieldMappingsRequest(); + getMappingsRequest.indices(anomalyDetector.getIndices().toArray(new String[0])).fields(givenTimeField); + getMappingsRequest.indicesOptions(IndicesOptions.strictExpand()); + + // comments explaining fieldMappingResponse parsing can be found inside following method: + // AbstractAnomalyDetectorActionHandler.validateCategoricalField(String, boolean) + ActionListener mappingsListener = ActionListener.wrap(getMappingsResponse -> { + boolean foundField = false; + Map> mappingsByIndex = getMappingsResponse.mappings(); + + for (Map mappingsByField : mappingsByIndex.values()) { + for (Map.Entry field2Metadata : mappingsByField.entrySet()) { + + GetFieldMappingsResponse.FieldMappingMetadata fieldMetadata = field2Metadata.getValue(); + if 
(fieldMetadata != null) { + // sourceAsMap returns sth like {host2={type=keyword}} with host2 being a nested field + Map fieldMap = fieldMetadata.sourceAsMap(); + if (fieldMap != null) { + for (Object type : fieldMap.values()) { + if (type instanceof Map) { + foundField = true; + Map metadataMap = (Map) type; + String typeName = (String) metadataMap.get(CommonName.TYPE); + if (!typeName.equals(CommonName.DATE_TYPE)) { + listener + .onFailure( + new ValidationException( + String.format(Locale.ROOT, CommonMessages.INVALID_TIMESTAMP, givenTimeField), + ValidationIssueType.TIMEFIELD_FIELD, + ValidationAspect.DETECTOR + ) + ); + return; + } + } + } + } + } + } + } + if (!foundField) { + listener + .onFailure( + new ValidationException( + String.format(Locale.ROOT, CommonMessages.NON_EXISTENT_TIMESTAMP, givenTimeField), + ValidationIssueType.TIMEFIELD_FIELD, + ValidationAspect.DETECTOR + ) + ); + return; + } + prepareAnomalyDetectorIndexing(indexingDryRun); + }, error -> { + String message = String.format(Locale.ROOT, "Fail to get the index mapping of %s", anomalyDetector.getIndices()); + logger.error(message, error); + listener.onFailure(new IllegalArgumentException(message)); + }); + clientUtil.executeWithInjectedSecurity(GetFieldMappingsAction.INSTANCE, getMappingsRequest, user, client, mappingsListener); + } + + /** + * Prepare for indexing a new anomaly detector. + * @param indexingDryRun if this is dryrun for indexing; when validation, it is true; when create/update, it is false + */ + protected void prepareAnomalyDetectorIndexing(boolean indexingDryRun) { + if (method == RestRequest.Method.PUT) { + handler + .getDetectorJob( + clusterService, + client, + detectorId, + listener, + () -> updateAnomalyDetector(detectorId, indexingDryRun), + xContentRegistry + ); + } else { + createAnomalyDetector(indexingDryRun); + } + } + + protected void updateAnomalyDetector(String detectorId, boolean indexingDryRun) { + GetRequest request = new GetRequest(CommonName.CONFIG_INDEX, detectorId); + client + .get( + request, + ActionListener + .wrap( + response -> onGetAnomalyDetectorResponse(response, indexingDryRun, detectorId), + exception -> listener.onFailure(exception) + ) + ); + } + + private void onGetAnomalyDetectorResponse(GetResponse response, boolean indexingDryRun, String detectorId) { + if (!response.isExists()) { + listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, RestStatus.NOT_FOUND)); + return; + } + try (XContentParser parser = RestHandlerUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetector existingDetector = AnomalyDetector.parse(parser, response.getId(), response.getVersion()); + // If detector category field changed, frontend may not be able to render AD result for different detector types correctly. + // For example, if detector changed from HC to single entity detector, AD result page may show multiple anomaly + // result points on the same time point if there are multiple entities have anomaly results. + // If single-category HC changed category field from IP to error type, the AD result page may show both IP and error type + // in top N entities list. That's confusing. + // So we decide to block updating detector category field. 
+ if (!listEqualsWithoutConsideringOrder(existingDetector.getCategoryFields(), anomalyDetector.getCategoryFields())) { + listener.onFailure(new OpenSearchStatusException(CommonMessages.CAN_NOT_CHANGE_CATEGORY_FIELD, RestStatus.BAD_REQUEST)); + return; + } + if (!Objects.equals(existingDetector.getCustomResultIndex(), anomalyDetector.getCustomResultIndex())) { + listener + .onFailure(new OpenSearchStatusException(CommonMessages.CAN_NOT_CHANGE_CUSTOM_RESULT_INDEX, RestStatus.BAD_REQUEST)); + return; + } + + adTaskManager.getAndExecuteOnLatestDetectorLevelTask(detectorId, HISTORICAL_DETECTOR_TASK_TYPES, (adTask) -> { + if (adTask.isPresent() && !adTask.get().isDone()) { + // can't update detector if there is AD task running + listener.onFailure(new OpenSearchStatusException("Detector is running", RestStatus.INTERNAL_SERVER_ERROR)); + } else { + validateExistingDetector(existingDetector, indexingDryRun); + } + }, transportService, true, listener); + } catch (IOException e) { + String message = "Failed to parse anomaly detector " + detectorId; + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + + } + + protected void validateExistingDetector(AnomalyDetector existingDetector, boolean indexingDryRun) { + if (!hasCategoryField(existingDetector) && hasCategoryField(this.anomalyDetector)) { + validateAgainstExistingMultiEntityAnomalyDetector(detectorId, indexingDryRun); + } else { + validateCategoricalField(detectorId, indexingDryRun); + } + } + + protected boolean hasCategoryField(AnomalyDetector detector) { + return detector.getCategoryFields() != null && !detector.getCategoryFields().isEmpty(); + } + + protected void validateAgainstExistingMultiEntityAnomalyDetector(String detectorId, boolean indexingDryRun) { + if (anomalyDetectionIndices.doesConfigIndexExist()) { + QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.existsQuery(AnomalyDetector.CATEGORY_FIELD)); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); + + SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); + client + .search( + searchRequest, + ActionListener + .wrap( + response -> onSearchMultiEntityAdResponse(response, detectorId, indexingDryRun), + exception -> listener.onFailure(exception) + ) + ); + } else { + validateCategoricalField(detectorId, indexingDryRun); + } + + } + + protected void createAnomalyDetector(boolean indexingDryRun) { + try { + List categoricalFields = anomalyDetector.getCategoryFields(); + if (categoricalFields != null && categoricalFields.size() > 0) { + validateAgainstExistingMultiEntityAnomalyDetector(null, indexingDryRun); + } else { + if (anomalyDetectionIndices.doesConfigIndexExist()) { + QueryBuilder query = QueryBuilders.matchAllQuery(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); + + SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); + + client + .search( + searchRequest, + ActionListener + .wrap( + response -> onSearchSingleEntityAdResponse(response, indexingDryRun), + exception -> listener.onFailure(exception) + ) + ); + } else { + searchAdInputIndices(null, indexingDryRun); + } + + } + } catch (Exception e) { + listener.onFailure(e); + } + } + + protected void onSearchSingleEntityAdResponse(SearchResponse response, boolean indexingDryRun) throws IOException { + if 
(response.getHits().getTotalHits().value >= maxSingleEntityAnomalyDetectors) { + String errorMsgSingleEntity = String + .format(Locale.ROOT, EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG, maxSingleEntityAnomalyDetectors); + logger.error(errorMsgSingleEntity); + if (indexingDryRun) { + listener + .onFailure( + new ValidationException(errorMsgSingleEntity, ValidationIssueType.GENERAL_SETTINGS, ValidationAspect.DETECTOR) + ); + return; + } + listener.onFailure(new IllegalArgumentException(errorMsgSingleEntity)); + } else { + searchAdInputIndices(null, indexingDryRun); + } + } + + protected void onSearchMultiEntityAdResponse(SearchResponse response, String detectorId, boolean indexingDryRun) throws IOException { + if (response.getHits().getTotalHits().value >= maxMultiEntityAnomalyDetectors) { + String errorMsg = String.format(Locale.ROOT, EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG, maxMultiEntityAnomalyDetectors); + logger.error(errorMsg); + if (indexingDryRun) { + listener.onFailure(new ValidationException(errorMsg, ValidationIssueType.GENERAL_SETTINGS, ValidationAspect.DETECTOR)); + return; + } + listener.onFailure(new IllegalArgumentException(errorMsg)); + } else { + validateCategoricalField(detectorId, indexingDryRun); + } + } + + @SuppressWarnings("unchecked") + protected void validateCategoricalField(String detectorId, boolean indexingDryRun) { + List categoryField = anomalyDetector.getCategoryFields(); + + if (categoryField == null) { + searchAdInputIndices(detectorId, indexingDryRun); + return; + } + + // we only support a certain number of categorical field + // If there is more fields than required, AnomalyDetector's constructor + // throws ADValidationException before reaching this line + int maxCategoryFields = ADNumericSetting.maxCategoricalFields(); + if (categoryField.size() > maxCategoryFields) { + listener + .onFailure( + new ValidationException( + CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields), + ValidationIssueType.CATEGORY, + ValidationAspect.DETECTOR + ) + ); + return; + } + + String categoryField0 = categoryField.get(0); + + GetFieldMappingsRequest getMappingsRequest = new GetFieldMappingsRequest(); + getMappingsRequest.indices(anomalyDetector.getIndices().toArray(new String[0])).fields(categoryField.toArray(new String[0])); + getMappingsRequest.indicesOptions(IndicesOptions.strictExpand()); + + ActionListener mappingsListener = ActionListener.wrap(getMappingsResponse -> { + // example getMappingsResponse: + // GetFieldMappingsResponse{mappings={server-metrics={_doc={service=FieldMappingMetadata{fullName='service', + // source=org.opensearch.core.common.bytes.BytesArray@7ba87dbd}}}}} + // for nested field, it would be + // GetFieldMappingsResponse{mappings={server-metrics={_doc={host_nest.host2=FieldMappingMetadata{fullName='host_nest.host2', + // source=org.opensearch.core.common.bytes.BytesArray@8fb4de08}}}}} + boolean foundField = false; + + // Review why the change from FieldMappingMetadata to GetFieldMappingsResponse.FieldMappingMetadata + Map> mappingsByIndex = getMappingsResponse.mappings(); + + for (Map mappingsByField : mappingsByIndex.values()) { + for (Map.Entry field2Metadata : mappingsByField.entrySet()) { + // example output: + // host_nest.host2=FieldMappingMetadata{fullName='host_nest.host2', + // source=org.opensearch.core.common.bytes.BytesArray@8fb4de08} + + // Review why the change from FieldMappingMetadata to GetFieldMappingsResponse.FieldMappingMetadata + + GetFieldMappingsResponse.FieldMappingMetadata fieldMetadata = 
field2Metadata.getValue(); + + if (fieldMetadata != null) { + // sourceAsMap returns sth like {host2={type=keyword}} with host2 being a nested field + Map fieldMap = fieldMetadata.sourceAsMap(); + if (fieldMap != null) { + for (Object type : fieldMap.values()) { + if (type != null && type instanceof Map) { + foundField = true; + Map metadataMap = (Map) type; + String typeName = (String) metadataMap.get(CommonName.TYPE); + if (!typeName.equals(CommonName.KEYWORD_TYPE) && !typeName.equals(CommonName.IP_TYPE)) { + listener + .onFailure( + new ValidationException( + CATEGORICAL_FIELD_TYPE_ERR_MSG, + ValidationIssueType.CATEGORY, + ValidationAspect.DETECTOR + ) + ); + return; + } + } + } + } + + } + } + } + + if (foundField == false) { + listener + .onFailure( + new ValidationException( + String.format(Locale.ROOT, CATEGORY_NOT_FOUND_ERR_MSG, categoryField0), + ValidationIssueType.CATEGORY, + ValidationAspect.DETECTOR + ) + ); + return; + } + + searchAdInputIndices(detectorId, indexingDryRun); + }, error -> { + String message = String.format(Locale.ROOT, "Fail to get the index mapping of %s", anomalyDetector.getIndices()); + logger.error(message, error); + listener.onFailure(new IllegalArgumentException(message)); + }); + + clientUtil.executeWithInjectedSecurity(GetFieldMappingsAction.INSTANCE, getMappingsRequest, user, client, mappingsListener); + } + + protected void searchAdInputIndices(String detectorId, boolean indexingDryRun) { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(QueryBuilders.matchAllQuery()) + .size(0) + .timeout(requestTimeout); + + SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); + + ActionListener searchResponseListener = ActionListener + .wrap( + searchResponse -> onSearchAdInputIndicesResponse(searchResponse, detectorId, indexingDryRun), + exception -> listener.onFailure(exception) + ); + + clientUtil.asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, searchResponseListener); + } + + protected void onSearchAdInputIndicesResponse(SearchResponse response, String detectorId, boolean indexingDryRun) throws IOException { + if (response.getHits().getTotalHits().value == 0) { + String errorMsg = NO_DOCS_IN_USER_INDEX_MSG + Arrays.toString(anomalyDetector.getIndices().toArray(new String[0])); + logger.error(errorMsg); + if (indexingDryRun) { + listener.onFailure(new ValidationException(errorMsg, ValidationIssueType.INDICES, ValidationAspect.DETECTOR)); + return; + } + listener.onFailure(new IllegalArgumentException(errorMsg)); + } else { + validateAnomalyDetectorFeatures(detectorId, indexingDryRun); + } + } + + protected void checkADNameExists(String detectorId, boolean indexingDryRun) throws IOException { + if (anomalyDetectionIndices.doesConfigIndexExist()) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + // src/main/resources/mappings/anomaly-detectors.json#L14 + boolQueryBuilder.must(QueryBuilders.termQuery("name.keyword", anomalyDetector.getName())); + if (StringUtils.isNotBlank(detectorId)) { + boolQueryBuilder.mustNot(QueryBuilders.termQuery(RestHandlerUtils._ID, detectorId)); + } + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQueryBuilder).timeout(requestTimeout); + SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); + client + .search( + searchRequest, + ActionListener + .wrap( + searchResponse -> 
onSearchADNameResponse(searchResponse, detectorId, anomalyDetector.getName(), indexingDryRun), + exception -> listener.onFailure(exception) + ) + ); + } else { + tryIndexingAnomalyDetector(indexingDryRun); + } + + } + + protected void onSearchADNameResponse(SearchResponse response, String detectorId, String name, boolean indexingDryRun) + throws IOException { + if (response.getHits().getTotalHits().value > 0) { + String errorMsg = String + .format( + Locale.ROOT, + DUPLICATE_DETECTOR_MSG, + name, + Arrays.stream(response.getHits().getHits()).map(hit -> hit.getId()).collect(Collectors.toList()) + ); + logger.warn(errorMsg); + listener.onFailure(new ValidationException(errorMsg, ValidationIssueType.NAME, ValidationAspect.DETECTOR)); + } else { + tryIndexingAnomalyDetector(indexingDryRun); + } + } + + protected void tryIndexingAnomalyDetector(boolean indexingDryRun) throws IOException { + if (!indexingDryRun) { + indexAnomalyDetector(detectorId); + } else { + finishDetectorValidationOrContinueToModelValidation(); + } + } + + protected Set getValidationTypes(String validationType) { + if (StringUtils.isBlank(validationType)) { + return DEFAULT_VALIDATION_ASPECTS; + } else { + Set typesInRequest = new HashSet<>(Arrays.asList(validationType.split(","))); + return ValidationAspect + .getNames(Sets.intersection(RestValidateAnomalyDetectorAction.ALL_VALIDATION_ASPECTS_STRS, typesInRequest)); + } + } + + protected void finishDetectorValidationOrContinueToModelValidation() { + logger.info("Skipping indexing detector. No blocking issue found so far."); + if (!getValidationTypes(validationType).contains(ValidationAspect.MODEL)) { + listener.onResponse(null); + } else { + ModelValidationActionHandler modelValidationActionHandler = new ModelValidationActionHandler( + clusterService, + client, + clientUtil, + (ActionListener) listener, + anomalyDetector, + requestTimeout, + xContentRegistry, + searchFeatureDao, + validationType, + clock, + settings, + user + ); + modelValidationActionHandler.checkIfMultiEntityDetector(); + } + } + + @SuppressWarnings("unchecked") + protected void indexAnomalyDetector(String detectorId) throws IOException { + AnomalyDetector detector = new AnomalyDetector( + anomalyDetector.getId(), + anomalyDetector.getVersion(), + anomalyDetector.getName(), + anomalyDetector.getDescription(), + anomalyDetector.getTimeField(), + anomalyDetector.getIndices(), + anomalyDetector.getFeatureAttributes(), + anomalyDetector.getFilterQuery(), + anomalyDetector.getInterval(), + anomalyDetector.getWindowDelay(), + anomalyDetector.getShingleSize(), + anomalyDetector.getUiMetadata(), + anomalyDetector.getSchemaVersion(), + Instant.now(), + anomalyDetector.getCategoryFields(), + user, + anomalyDetector.getCustomResultIndex(), + anomalyDetector.getImputationOption() + ); + IndexRequest indexRequest = new IndexRequest(CommonName.CONFIG_INDEX) + .setRefreshPolicy(refreshPolicy) + .source(detector.toXContent(XContentFactory.jsonBuilder(), XCONTENT_WITH_TYPE)) + .setIfSeqNo(seqNo) + .setIfPrimaryTerm(primaryTerm) + .timeout(requestTimeout); + if (StringUtils.isNotBlank(detectorId)) { + indexRequest.id(detectorId); + } + + client.index(indexRequest, new ActionListener() { + @Override + public void onResponse(IndexResponse indexResponse) { + String errorMsg = checkShardsFailure(indexResponse); + if (errorMsg != null) { + listener.onFailure(new OpenSearchStatusException(errorMsg, indexResponse.status())); + return; + } + listener + .onResponse( + (T) new IndexAnomalyDetectorResponse( + indexResponse.getId(), 
+ indexResponse.getVersion(), + indexResponse.getSeqNo(), + indexResponse.getPrimaryTerm(), + detector, + RestStatus.CREATED + ) + ); + } + + @Override + public void onFailure(Exception e) { + logger.warn("Failed to update detector", e); + if (e.getMessage() != null && e.getMessage().contains("version conflict")) { + listener + .onFailure( + new IllegalArgumentException("There was a problem updating the historical detector:[" + detectorId + "]") + ); + } else { + listener.onFailure(e); + } + } + }); + } + + protected void onCreateMappingsResponse(CreateIndexResponse response, boolean indexingDryRun) throws IOException { + if (response.isAcknowledged()) { + logger.info("Created {} with mappings.", CommonName.CONFIG_INDEX); + prepareAnomalyDetectorIndexing(indexingDryRun); + } else { + logger.warn("Created {} with mappings call not acknowledged.", CommonName.CONFIG_INDEX); + listener + .onFailure( + new OpenSearchStatusException( + "Created " + CommonName.CONFIG_INDEX + "with mappings call not acknowledged.", + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + } + } + + protected String checkShardsFailure(IndexResponse response) { + StringBuilder failureReasons = new StringBuilder(); + if (response.getShardInfo().getFailed() > 0) { + for (ReplicationResponse.ShardInfo.Failure failure : response.getShardInfo().getFailures()) { + failureReasons.append(failure); + } + return failureReasons.toString(); + } + return null; + } + + /** + * Validate config/syntax, and runtime error of detector features + * @param detectorId detector id + * @param indexingDryRun if false, then will eventually index detector; true, skip indexing detector + * @throws IOException when fail to parse feature aggregation + */ + // TODO: move this method to util class so that it can be re-usable for more use cases + // https://github.com/opensearch-project/anomaly-detection/issues/39 + protected void validateAnomalyDetectorFeatures(String detectorId, boolean indexingDryRun) throws IOException { + if (anomalyDetector != null + && (anomalyDetector.getFeatureAttributes() == null || anomalyDetector.getFeatureAttributes().isEmpty())) { + checkADNameExists(detectorId, indexingDryRun); + return; + } + // checking configuration/syntax error of detector features + String error = RestHandlerUtils.checkFeaturesSyntax(anomalyDetector, maxAnomalyFeatures); + if (StringUtils.isNotBlank(error)) { + if (indexingDryRun) { + listener.onFailure(new ValidationException(error, ValidationIssueType.FEATURE_ATTRIBUTES, ValidationAspect.DETECTOR)); + return; + } + listener.onFailure(new OpenSearchStatusException(error, RestStatus.BAD_REQUEST)); + return; + } + // checking runtime error from feature query + ActionListener>> validateFeatureQueriesListener = ActionListener + .wrap( + response -> { checkADNameExists(detectorId, indexingDryRun); }, + exception -> { + listener + .onFailure( + new ValidationException( + exception.getMessage(), + ValidationIssueType.FEATURE_ATTRIBUTES, + ValidationAspect.DETECTOR + ) + ); + } + ); + MultiResponsesDelegateActionListener>> multiFeatureQueriesResponseListener = + new MultiResponsesDelegateActionListener>>( + validateFeatureQueriesListener, + anomalyDetector.getFeatureAttributes().size(), + String.format(Locale.ROOT, "Validation failed for feature(s) of detector %s", anomalyDetector.getName()), + false + ); + + for (Feature feature : anomalyDetector.getFeatureAttributes()) { + SearchSourceBuilder ssb = new SearchSourceBuilder().size(1).query(QueryBuilders.matchAllQuery()); + AggregatorFactories.Builder internalAgg = 
parseAggregators( + feature.getAggregation().toString(), + xContentRegistry, + feature.getId() + ); + ssb.aggregation(internalAgg.getAggregatorFactories().iterator().next()); + SearchRequest searchRequest = new SearchRequest().indices(anomalyDetector.getIndices().toArray(new String[0])).source(ssb); + ActionListener searchResponseListener = ActionListener.wrap(response -> { + Optional aggFeatureResult = searchFeatureDao.parseResponse(response, Arrays.asList(feature.getId())); + if (aggFeatureResult.isPresent()) { + multiFeatureQueriesResponseListener + .onResponse( + new MergeableList>(new ArrayList>(Arrays.asList(aggFeatureResult))) + ); + } else { + String errorMessage = CommonMessages.FEATURE_WITH_EMPTY_DATA_MSG + feature.getName(); + logger.error(errorMessage); + multiFeatureQueriesResponseListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST)); + } + }, e -> { + String errorMessage; + if (isExceptionCausedByInvalidQuery(e)) { + errorMessage = CommonMessages.FEATURE_WITH_INVALID_QUERY_MSG + feature.getName(); + } else { + errorMessage = CommonMessages.UNKNOWN_SEARCH_QUERY_EXCEPTION_MSG + feature.getName(); + } + logger.error(errorMessage, e); + multiFeatureQueriesResponseListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST, e)); + }); + clientUtil.asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, searchResponseListener); + } + } +} diff --git a/src/main/java/org/opensearch/ad/rest/handler/AnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/AnomalyDetectorActionHandler.java index ed5376be9..f279f8b63 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/AnomalyDetectorActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/AnomalyDetectorActionHandler.java @@ -11,7 +11,7 @@ package org.opensearch.ad.rest.handler; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; @@ -24,9 +24,9 @@ import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.rest.RestStatus; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.util.RestHandlerUtils; diff --git a/src/main/java/org/opensearch/ad/rest/handler/AnomalyDetectorActionHandler.java-e b/src/main/java/org/opensearch/ad/rest/handler/AnomalyDetectorActionHandler.java-e new file mode 100644 index 000000000..7a1939d3e --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/handler/AnomalyDetectorActionHandler.java-e @@ -0,0 +1,105 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.rest.handler; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.util.RestHandlerUtils; + +/** + * Common handler to process AD request. + */ +public class AnomalyDetectorActionHandler { + + private final Logger logger = LogManager.getLogger(AnomalyDetectorActionHandler.class); + + /** + * Get detector job for update/delete AD job. + * If AD job exist, will return error message; otherwise, execute function. + * + * @param clusterService ES cluster service + * @param client ES node client + * @param detectorId detector identifier + * @param listener Listener to send response + * @param function AD function + * @param xContentRegistry Registry which is used for XContentParser + */ + public void getDetectorJob( + ClusterService clusterService, + Client client, + String detectorId, + ActionListener listener, + ExecutorFunction function, + NamedXContentRegistry xContentRegistry + ) { + if (clusterService.state().metadata().indices().containsKey(CommonName.JOB_INDEX)) { + GetRequest request = new GetRequest(CommonName.JOB_INDEX).id(detectorId); + client + .get( + request, + ActionListener + .wrap(response -> onGetAdJobResponseForWrite(response, listener, function, xContentRegistry), exception -> { + logger.error("Fail to get anomaly detector job: " + detectorId, exception); + listener.onFailure(exception); + }) + ); + } else { + function.execute(); + } + } + + private void onGetAdJobResponseForWrite( + GetResponse response, + ActionListener listener, + ExecutorFunction function, + NamedXContentRegistry xContentRegistry + ) { + if (response.isExists()) { + String adJobId = response.getId(); + if (adJobId != null) { + // check if AD job is running on the detector, if yes, we can't delete the detector + try ( + XContentParser parser = RestHandlerUtils + .createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetectorJob adJob = AnomalyDetectorJob.parse(parser); + if (adJob.isEnabled()) { + listener.onFailure(new OpenSearchStatusException("Detector job is running: " + adJobId, RestStatus.BAD_REQUEST)); + return; + } + } catch (IOException e) { + String message = "Failed to parse anomaly detector job " + adJobId; + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.BAD_REQUEST)); + } + } + } + function.execute(); + } +} diff --git a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java-e b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java-e new file mode 100644 index 000000000..b401ce007 --- /dev/null +++ 
b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java-e @@ -0,0 +1,121 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest.handler; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.rest.RestRequest; +import org.opensearch.transport.TransportService; + +/** + * Anomaly detector REST action handler to process POST/PUT request. + * POST request is for creating anomaly detector. + * PUT request is for updating anomaly detector. + */ +public class IndexAnomalyDetectorActionHandler extends AbstractAnomalyDetectorActionHandler { + + /** + * Constructor function. + * + * @param clusterService ClusterService + * @param client ES node client that executes actions on the local node + * @param clientUtil AD client util + * @param transportService ES transport service + * @param listener ES channel used to construct bytes / builder based outputs, and send responses + * @param anomalyDetectionIndices anomaly detector index manager + * @param detectorId detector identifier + * @param seqNo sequence number of last modification + * @param primaryTerm primary term of last modification + * @param refreshPolicy refresh policy + * @param anomalyDetector anomaly detector instance + * @param requestTimeout request time out configuration + * @param maxSingleEntityAnomalyDetectors max single-entity anomaly detectors allowed + * @param maxMultiEntityAnomalyDetectors max multi-entity detectors allowed + * @param maxAnomalyFeatures max features allowed per detector + * @param method Rest Method type + * @param xContentRegistry Registry which is used for XContentParser + * @param user User context + * @param adTaskManager AD Task manager + * @param searchFeatureDao Search feature dao + * @param settings Node settings + */ + public IndexAnomalyDetectorActionHandler( + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + TransportService transportService, + ActionListener listener, + ADIndexManagement anomalyDetectionIndices, + String detectorId, + Long seqNo, + Long primaryTerm, + WriteRequest.RefreshPolicy refreshPolicy, + AnomalyDetector anomalyDetector, + TimeValue requestTimeout, + Integer maxSingleEntityAnomalyDetectors, + Integer maxMultiEntityAnomalyDetectors, + Integer maxAnomalyFeatures, + RestRequest.Method method, + NamedXContentRegistry xContentRegistry, + User user, + ADTaskManager adTaskManager, + SearchFeatureDao searchFeatureDao, + Settings settings + ) { + super( + clusterService, + client, + clientUtil, + transportService, + listener, + anomalyDetectionIndices, + detectorId, + seqNo, + 
primaryTerm, + refreshPolicy, + anomalyDetector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + method, + xContentRegistry, + user, + adTaskManager, + searchFeatureDao, + null, + false, + null, + settings + ); + } + + /** + * Start function to process create/update anomaly detector request. + */ + @Override + public void start() { + super.start(); + } +} diff --git a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java index a81e00a35..824c6fc21 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java @@ -14,7 +14,7 @@ import static org.opensearch.action.DocWriteResponse.Result.CREATED; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.ad.util.ExceptionUtil.getShardsFailure; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; import java.io.IOException; @@ -45,11 +45,11 @@ import org.opensearch.client.Client; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; import org.opensearch.jobscheduler.spi.schedule.Schedule; -import org.opensearch.rest.RestStatus; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.model.IntervalTimeConfiguration; diff --git a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java-e b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java-e new file mode 100644 index 000000000..a2a55ecdc --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java-e @@ -0,0 +1,434 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.rest.handler; + +import static org.opensearch.action.DocWriteResponse.Result.CREATED; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.ad.util.ExceptionUtil.getShardsFailure; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; + +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.ExecuteADResultResponseRecorder; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTaskState; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.AnomalyDetectorJobResponse; +import org.opensearch.ad.transport.AnomalyResultAction; +import org.opensearch.ad.transport.AnomalyResultRequest; +import org.opensearch.ad.transport.StopDetectorAction; +import org.opensearch.ad.transport.StopDetectorRequest; +import org.opensearch.ad.transport.StopDetectorResponse; +import org.opensearch.client.Client; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.jobscheduler.spi.schedule.Schedule; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +import com.google.common.base.Throwables; + +/** + * Anomaly detector job REST action handler to process POST/PUT request. + */ +public class IndexAnomalyDetectorJobActionHandler { + + private final ADIndexManagement anomalyDetectionIndices; + private final String detectorId; + private final Long seqNo; + private final Long primaryTerm; + private final Client client; + private final NamedXContentRegistry xContentRegistry; + private final TransportService transportService; + private final ADTaskManager adTaskManager; + + private final Logger logger = LogManager.getLogger(IndexAnomalyDetectorJobActionHandler.class); + private final TimeValue requestTimeout; + private final ExecuteADResultResponseRecorder recorder; + + /** + * Constructor function. 
+ * + * @param client ES node client that executes actions on the local node + * @param anomalyDetectionIndices anomaly detector index manager + * @param detectorId detector identifier + * @param seqNo sequence number of last modification + * @param primaryTerm primary term of last modification + * @param requestTimeout request time out configuration + * @param xContentRegistry Registry which is used for XContentParser + * @param transportService transport service + * @param adTaskManager AD task manager + * @param recorder Utility to record AnomalyResultAction execution result + */ + public IndexAnomalyDetectorJobActionHandler( + Client client, + ADIndexManagement anomalyDetectionIndices, + String detectorId, + Long seqNo, + Long primaryTerm, + TimeValue requestTimeout, + NamedXContentRegistry xContentRegistry, + TransportService transportService, + ADTaskManager adTaskManager, + ExecuteADResultResponseRecorder recorder + ) { + this.client = client; + this.anomalyDetectionIndices = anomalyDetectionIndices; + this.detectorId = detectorId; + this.seqNo = seqNo; + this.primaryTerm = primaryTerm; + this.requestTimeout = requestTimeout; + this.xContentRegistry = xContentRegistry; + this.transportService = transportService; + this.adTaskManager = adTaskManager; + this.recorder = recorder; + } + + /** + * Start anomaly detector job. + * 1. If job doesn't exist, create new job. + * 2. If job exists: a). if job enabled, return error message; b). if job disabled, enable job. + * @param detector anomaly detector + * @param listener Listener to send responses + */ + public void startAnomalyDetectorJob(AnomalyDetector detector, ActionListener listener) { + // this start listener is created & injected throughout the job handler so that whenever the job response is received, + // there's the extra step of trying to index results and update detector state with a 60s delay. 
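+        // Descriptive note (added for clarity, based on the listener defined below): the wrapped start listener
+        // derives a single interval-sized execution window ending now, submits an AnomalyResultRequest for that
+        // window, and records either the response or the exception through the recorder before forwarding the
+        // original job response to the caller.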
+ ActionListener startListener = ActionListener.wrap(r -> { + try { + Instant executionEndTime = Instant.now(); + IntervalTimeConfiguration schedule = (IntervalTimeConfiguration) detector.getInterval(); + Instant executionStartTime = executionEndTime.minus(schedule.getInterval(), schedule.getUnit()); + AnomalyResultRequest getRequest = new AnomalyResultRequest( + detector.getId(), + executionStartTime.toEpochMilli(), + executionEndTime.toEpochMilli() + ); + client + .execute( + AnomalyResultAction.INSTANCE, + getRequest, + ActionListener + .wrap( + response -> recorder.indexAnomalyResult(executionStartTime, executionEndTime, response, detector), + exception -> { + + recorder + .indexAnomalyResultException( + executionStartTime, + executionEndTime, + Throwables.getStackTraceAsString(exception), + null, + detector + ); + } + ) + ); + } catch (Exception ex) { + listener.onFailure(ex); + return; + } + listener.onResponse(r); + + }, listener::onFailure); + if (!anomalyDetectionIndices.doesJobIndexExist()) { + anomalyDetectionIndices.initJobIndex(ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + logger.info("Created {} with mappings.", CommonName.CONFIG_INDEX); + createJob(detector, startListener); + } else { + logger.warn("Created {} with mappings call not acknowledged.", CommonName.CONFIG_INDEX); + startListener + .onFailure( + new OpenSearchStatusException( + "Created " + CommonName.CONFIG_INDEX + " with mappings call not acknowledged.", + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + } + }, exception -> startListener.onFailure(exception))); + } else { + createJob(detector, startListener); + } + } + + private void createJob(AnomalyDetector detector, ActionListener listener) { + try { + IntervalTimeConfiguration interval = (IntervalTimeConfiguration) detector.getInterval(); + Schedule schedule = new IntervalSchedule(Instant.now(), (int) interval.getInterval(), interval.getUnit()); + Duration duration = Duration.of(interval.getInterval(), interval.getUnit()); + + AnomalyDetectorJob job = new AnomalyDetectorJob( + detector.getId(), + schedule, + detector.getWindowDelay(), + true, + Instant.now(), + null, + Instant.now(), + duration.getSeconds(), + detector.getUser(), + detector.getCustomResultIndex() + ); + + getAnomalyDetectorJobForWrite(detector, job, listener); + } catch (Exception e) { + String message = "Failed to parse anomaly detector job " + detectorId; + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + } + + private void getAnomalyDetectorJobForWrite( + AnomalyDetector detector, + AnomalyDetectorJob job, + ActionListener listener + ) { + GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId); + + client + .get( + getRequest, + ActionListener + .wrap( + response -> onGetAnomalyDetectorJobForWrite(response, detector, job, listener), + exception -> listener.onFailure(exception) + ) + ); + } + + private void onGetAnomalyDetectorJobForWrite( + GetResponse response, + AnomalyDetector detector, + AnomalyDetectorJob job, + ActionListener listener + ) throws IOException { + if (response.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetectorJob currentAdJob = AnomalyDetectorJob.parse(parser); + if (currentAdJob.isEnabled()) { + listener + .onFailure(new OpenSearchStatusException("Anomaly detector job is 
already running: " + detectorId, RestStatus.OK)); + return; + } else { + AnomalyDetectorJob newJob = new AnomalyDetectorJob( + job.getName(), + job.getSchedule(), + job.getWindowDelay(), + job.isEnabled(), + Instant.now(), + currentAdJob.getDisabledTime(), + Instant.now(), + job.getLockDurationSeconds(), + job.getUser(), + job.getCustomResultIndex() + ); + // Get latest realtime task and check its state before index job. Will reset running realtime task + // as STOPPED first if job disabled, then start new job and create new realtime task. + adTaskManager + .startDetector( + detector, + null, + job.getUser(), + transportService, + ActionListener + .wrap( + r -> { indexAnomalyDetectorJob(newJob, null, listener); }, + e -> { + // Have logged error message in ADTaskManager#startDetector + listener.onFailure(e); + } + ) + ); + } + } catch (IOException e) { + String message = "Failed to parse anomaly detector job " + job.getName(); + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + } else { + adTaskManager + .startDetector( + detector, + null, + job.getUser(), + transportService, + ActionListener.wrap(r -> { indexAnomalyDetectorJob(job, null, listener); }, e -> listener.onFailure(e)) + ); + } + } + + private void indexAnomalyDetectorJob( + AnomalyDetectorJob job, + ExecutorFunction function, + ActionListener listener + ) throws IOException { + IndexRequest indexRequest = new IndexRequest(CommonName.JOB_INDEX) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .source(job.toXContent(XContentFactory.jsonBuilder(), RestHandlerUtils.XCONTENT_WITH_TYPE)) + .setIfSeqNo(seqNo) + .setIfPrimaryTerm(primaryTerm) + .timeout(requestTimeout) + .id(detectorId); + client + .index( + indexRequest, + ActionListener + .wrap( + response -> onIndexAnomalyDetectorJobResponse(response, function, listener), + exception -> listener.onFailure(exception) + ) + ); + } + + private void onIndexAnomalyDetectorJobResponse( + IndexResponse response, + ExecutorFunction function, + ActionListener listener + ) { + if (response == null || (response.getResult() != CREATED && response.getResult() != UPDATED)) { + String errorMsg = getShardsFailure(response); + listener.onFailure(new OpenSearchStatusException(errorMsg, response.status())); + return; + } + if (function != null) { + function.execute(); + } else { + AnomalyDetectorJobResponse anomalyDetectorJobResponse = new AnomalyDetectorJobResponse( + response.getId(), + response.getVersion(), + response.getSeqNo(), + response.getPrimaryTerm(), + RestStatus.OK + ); + listener.onResponse(anomalyDetectorJobResponse); + } + } + + /** + * Stop anomaly detector job. + * 1.If job not exists, return error message + * 2.If job exists: a).if job state is disabled, return error message; b).if job state is enabled, disable job. 
+ * + * @param detectorId detector identifier + * @param listener Listener to send responses + */ + public void stopAnomalyDetectorJob(String detectorId, ActionListener listener) { + GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId); + + client.get(getRequest, ActionListener.wrap(response -> { + if (response.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser); + if (!job.isEnabled()) { + adTaskManager.stopLatestRealtimeTask(detectorId, ADTaskState.STOPPED, null, transportService, listener); + } else { + AnomalyDetectorJob newJob = new AnomalyDetectorJob( + job.getName(), + job.getSchedule(), + job.getWindowDelay(), + false, + job.getEnabledTime(), + Instant.now(), + Instant.now(), + job.getLockDurationSeconds(), + job.getUser(), + job.getCustomResultIndex() + ); + indexAnomalyDetectorJob( + newJob, + () -> client + .execute( + StopDetectorAction.INSTANCE, + new StopDetectorRequest(detectorId), + stopAdDetectorListener(detectorId, listener) + ), + listener + ); + } + } catch (IOException e) { + String message = "Failed to parse anomaly detector job " + detectorId; + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + } else { + listener.onFailure(new OpenSearchStatusException("Anomaly detector job not exist: " + detectorId, RestStatus.BAD_REQUEST)); + } + }, exception -> listener.onFailure(exception))); + } + + private ActionListener stopAdDetectorListener( + String detectorId, + ActionListener listener + ) { + return new ActionListener() { + @Override + public void onResponse(StopDetectorResponse stopDetectorResponse) { + if (stopDetectorResponse.success()) { + logger.info("AD model deleted successfully for detector {}", detectorId); + // StopDetectorTransportAction will send out DeleteModelAction which will clear all realtime cache. + // Pass null transport service to method "stopLatestRealtimeTask" to not re-clear coordinating node cache. + adTaskManager.stopLatestRealtimeTask(detectorId, ADTaskState.STOPPED, null, null, listener); + } else { + logger.error("Failed to delete AD model for detector {}", detectorId); + // If failed to clear all realtime cache, will try to re-clear coordinating node cache. + adTaskManager + .stopLatestRealtimeTask( + detectorId, + ADTaskState.FAILED, + new OpenSearchStatusException("Failed to delete AD model", RestStatus.INTERNAL_SERVER_ERROR), + transportService, + listener + ); + } + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to delete AD model for detector " + detectorId, e); + // If failed to clear all realtime cache, will try to re-clear coordinating node cache. 
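+                // Descriptive note (added for clarity): unlike the success path above, which passes a null
+                // transport service, this failure path passes the real transport service so that
+                // stopLatestRealtimeTask can re-clear the coordinating node cache.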
+ adTaskManager + .stopLatestRealtimeTask( + detectorId, + ADTaskState.FAILED, + new OpenSearchStatusException("Failed to execute stop detector action", RestStatus.INTERNAL_SERVER_ERROR), + transportService, + listener + ); + } + }; + } + +} diff --git a/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java index 5be7e9534..ada684808 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java @@ -47,12 +47,12 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.commons.authuser.User; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.RangeQueryBuilder; -import org.opensearch.rest.RestStatus; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.Aggregations; diff --git a/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java-e b/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java-e new file mode 100644 index 000000000..5d683ec73 --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java-e @@ -0,0 +1,840 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.rest.handler; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.CONFIG_BUCKET_MINIMUM_SUCCESS_RATE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.INTERVAL_RECOMMENDATION_DECREASING_MULTIPLIER; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_INTERVAL_REC_LENGTH_IN_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_TIMES_DECREASING_INTERVAL; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.TOP_VALIDATE_TIMEOUT_IN_MILLIS; + +import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.MergeableList; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; +import 
org.opensearch.ad.util.MultiResponsesDelegateActionListener; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.PipelineAggregatorBuilders; +import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.histogram.DateHistogramInterval; +import org.opensearch.search.aggregations.bucket.histogram.Histogram; +import org.opensearch.search.aggregations.bucket.histogram.LongBounds; +import org.opensearch.search.aggregations.bucket.terms.Terms; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.TimeConfiguration; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.util.ParseUtils; + +/** + *

This class executes all validation checks that are not blocking on the 'model' level. + * This mostly involves checking if the data is generally dense enough to complete model training, + * which is based on whether enough buckets in the last x intervals have at least 1 document present. + * + * Initially, different bucket aggregations are executed with every configuration applied and with + * different varying intervals in order to find the best interval for the data. If no interval is found with all + * configurations applied, then each configuration is tested sequentially for sparsity.
+ */ +// TODO: Add more UT and IT +public class ModelValidationActionHandler { + protected static final String AGG_NAME_TOP = "top_agg"; + protected static final String AGGREGATION = "agg"; + protected final AnomalyDetector anomalyDetector; + protected final ClusterService clusterService; + protected final Logger logger = LogManager.getLogger(AbstractAnomalyDetectorActionHandler.class); + protected final TimeValue requestTimeout; + protected final AnomalyDetectorActionHandler handler = new AnomalyDetectorActionHandler(); + protected final Client client; + protected final SecurityClientUtil clientUtil; + protected final NamedXContentRegistry xContentRegistry; + protected final ActionListener listener; + protected final SearchFeatureDao searchFeatureDao; + protected final Clock clock; + protected final String validationType; + protected final Settings settings; + protected final User user; + + /** + * Constructor function. + * + * @param clusterService ClusterService + * @param client ES node client that executes actions on the local node + * @param clientUtil AD client util + * @param listener ES channel used to construct bytes / builder based outputs, and send responses + * @param anomalyDetector anomaly detector instance + * @param requestTimeout request time out configuration + * @param xContentRegistry Registry which is used for XContentParser + * @param searchFeatureDao Search feature DAO + * @param validationType Specified type for validation + * @param clock clock object to know when to timeout + * @param settings Node settings + * @param user User info + */ + public ModelValidationActionHandler( + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + ActionListener listener, + AnomalyDetector anomalyDetector, + TimeValue requestTimeout, + NamedXContentRegistry xContentRegistry, + SearchFeatureDao searchFeatureDao, + String validationType, + Clock clock, + Settings settings, + User user + ) { + this.clusterService = clusterService; + this.client = client; + this.clientUtil = clientUtil; + this.listener = listener; + this.anomalyDetector = anomalyDetector; + this.requestTimeout = requestTimeout; + this.xContentRegistry = xContentRegistry; + this.searchFeatureDao = searchFeatureDao; + this.validationType = validationType; + this.clock = clock; + this.settings = settings; + this.user = user; + } + + // Need to first check if multi entity detector or not before doing any sort of validation. + // If detector is HCAD then we will find the top entity and treat as single entity for + // validation purposes + public void checkIfMultiEntityDetector() { + ActionListener> recommendationListener = ActionListener + .wrap(topEntity -> getLatestDateForValidation(topEntity), exception -> { + listener.onFailure(exception); + logger.error("Failed to get top entity for categorical field", exception); + }); + if (anomalyDetector.isHighCardinality()) { + getTopEntity(recommendationListener); + } else { + recommendationListener.onResponse(Collections.emptyMap()); + } + } + + // For single category HCAD, this method uses bucket aggregation and sort to get the category field + // that have the highest document count in order to use that top entity for further validation + // For multi-category HCADs we use a composite aggregation to find the top fields for the entity + // with the highest doc count. 
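+    // Descriptive note (added for clarity, based on the method below): the aggregation result is reduced to a
+    // map from category field name to its most frequent value; an empty map is returned when the aggregation is
+    // missing or any top key is null, which the high-cardinality validation path treats as a too-sparse category
+    // field.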
+ private void getTopEntity(ActionListener> topEntityListener) { + // Look at data back to the lower bound given the max interval we recommend or one given + long maxIntervalInMinutes = Math.max(MAX_INTERVAL_REC_LENGTH_IN_MINUTES, anomalyDetector.getIntervalInMinutes()); + LongBounds timeRangeBounds = getTimeRangeBounds( + Instant.now().toEpochMilli(), + new IntervalTimeConfiguration(maxIntervalInMinutes, ChronoUnit.MINUTES) + ); + RangeQueryBuilder rangeQuery = new RangeQueryBuilder(anomalyDetector.getTimeField()) + .from(timeRangeBounds.getMin()) + .to(timeRangeBounds.getMax()); + AggregationBuilder bucketAggs; + Map topKeys = new HashMap<>(); + if (anomalyDetector.getCategoryFields().size() == 1) { + bucketAggs = AggregationBuilders + .terms(AGG_NAME_TOP) + .field(anomalyDetector.getCategoryFields().get(0)) + .order(BucketOrder.count(true)); + } else { + bucketAggs = AggregationBuilders + .composite( + AGG_NAME_TOP, + anomalyDetector + .getCategoryFields() + .stream() + .map(f -> new TermsValuesSourceBuilder(f).field(f)) + .collect(Collectors.toList()) + ) + .size(1000) + .subAggregation( + PipelineAggregatorBuilders + .bucketSort("bucketSort", Collections.singletonList(new FieldSortBuilder("_count").order(SortOrder.DESC))) + .size(1) + ); + } + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(rangeQuery) + .aggregation(bucketAggs) + .trackTotalHits(false) + .size(0); + SearchRequest searchRequest = new SearchRequest() + .indices(anomalyDetector.getIndices().toArray(new String[0])) + .source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener.wrap(response -> { + Aggregations aggs = response.getAggregations(); + if (aggs == null) { + topEntityListener.onResponse(Collections.emptyMap()); + return; + } + if (anomalyDetector.getCategoryFields().size() == 1) { + Terms entities = aggs.get(AGG_NAME_TOP); + Object key = entities + .getBuckets() + .stream() + .max(Comparator.comparingInt(entry -> (int) entry.getDocCount())) + .map(MultiBucketsAggregation.Bucket::getKeyAsString) + .orElse(null); + topKeys.put(anomalyDetector.getCategoryFields().get(0), key); + } else { + CompositeAggregation compositeAgg = aggs.get(AGG_NAME_TOP); + topKeys + .putAll( + compositeAgg + .getBuckets() + .stream() + .flatMap(bucket -> bucket.getKey().entrySet().stream()) // this would create a flattened stream of map entries + .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())) + ); + } + for (Map.Entry entry : topKeys.entrySet()) { + if (entry.getValue() == null) { + topEntityListener.onResponse(Collections.emptyMap()); + return; + } + } + topEntityListener.onResponse(topKeys); + }, topEntityListener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + searchResponseListener + ); + } + + private void getLatestDateForValidation(Map topEntity) { + ActionListener> latestTimeListener = ActionListener + .wrap(latest -> getSampleRangesForValidationChecks(latest, anomalyDetector, listener, topEntity), exception -> { + listener.onFailure(exception); + logger.error("Failed to create search request for last data point", exception); + }); + searchFeatureDao.getLatestDataTime(anomalyDetector, latestTimeListener); + } + + private void getSampleRangesForValidationChecks( + Optional latestTime, + AnomalyDetector detector, + ActionListener listener, + Map topEntity 
+ ) { + if (!latestTime.isPresent() || latestTime.get() <= 0) { + listener + .onFailure( + new ValidationException( + ADCommonMessages.TIME_FIELD_NOT_ENOUGH_HISTORICAL_DATA, + ValidationIssueType.TIMEFIELD_FIELD, + ValidationAspect.MODEL + ) + ); + return; + } + long timeRangeEnd = Math.min(Instant.now().toEpochMilli(), latestTime.get()); + try { + getBucketAggregates(timeRangeEnd, listener, topEntity); + } catch (IOException e) { + listener.onFailure(new EndRunException(detector.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, true)); + } + } + + private void getBucketAggregates( + long latestTime, + ActionListener listener, + Map topEntity + ) throws IOException { + AggregationBuilder aggregation = getBucketAggregation(latestTime, (IntervalTimeConfiguration) anomalyDetector.getInterval()); + BoolQueryBuilder query = QueryBuilders.boolQuery().filter(anomalyDetector.getFilterQuery()); + if (anomalyDetector.isHighCardinality()) { + if (topEntity.isEmpty()) { + listener + .onFailure( + new ValidationException( + ADCommonMessages.CATEGORY_FIELD_TOO_SPARSE, + ValidationIssueType.CATEGORY, + ValidationAspect.MODEL + ) + ); + return; + } + for (Map.Entry entry : topEntity.entrySet()) { + query.filter(QueryBuilders.termQuery(entry.getKey(), entry.getValue())); + } + } + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(query) + .aggregation(aggregation) + .size(0) + .timeout(requestTimeout); + SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); + ActionListener intervalListener = ActionListener + .wrap(interval -> processIntervalRecommendation(interval, latestTime), exception -> { + listener.onFailure(exception); + logger.error("Failed to get interval recommendation", exception); + }); + final ActionListener searchResponseListener = + new ModelValidationActionHandler.DetectorIntervalRecommendationListener( + intervalListener, + searchRequest.source(), + (IntervalTimeConfiguration) anomalyDetector.getInterval(), + clock.millis() + TOP_VALIDATE_TIMEOUT_IN_MILLIS, + latestTime, + false, + MAX_TIMES_DECREASING_INTERVAL + ); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + searchResponseListener + ); + } + + private double processBucketAggregationResults(Histogram buckets) { + int docCountOverOne = 0; + // For each entry + for (Histogram.Bucket entry : buckets.getBuckets()) { + if (entry.getDocCount() > 0) { + docCountOverOne++; + } + } + return (docCountOverOne / (double) getNumberOfSamples()); + } + + /** + * ActionListener class to handle execution of multiple bucket aggregations one after the other + * Bucket aggregation with different interval lengths are executed one by one to check if the data is dense enough + * We only need to execute the next query if the previous one led to data that is too sparse. 
+ */ + class DetectorIntervalRecommendationListener implements ActionListener { + private final ActionListener intervalListener; + SearchSourceBuilder searchSourceBuilder; + IntervalTimeConfiguration detectorInterval; + private final long expirationEpochMs; + private final long latestTime; + boolean decreasingInterval; + int numTimesDecreasing; // maximum amount of times we will try decreasing interval for recommendation + + DetectorIntervalRecommendationListener( + ActionListener intervalListener, + SearchSourceBuilder searchSourceBuilder, + IntervalTimeConfiguration detectorInterval, + long expirationEpochMs, + long latestTime, + boolean decreasingInterval, + int numTimesDecreasing + ) { + this.intervalListener = intervalListener; + this.searchSourceBuilder = searchSourceBuilder; + this.detectorInterval = detectorInterval; + this.expirationEpochMs = expirationEpochMs; + this.latestTime = latestTime; + this.decreasingInterval = decreasingInterval; + this.numTimesDecreasing = numTimesDecreasing; + } + + @Override + public void onResponse(SearchResponse response) { + try { + Histogram aggregate = checkBucketResultErrors(response); + if (aggregate == null) { + return; + } + + long newIntervalMinute; + if (decreasingInterval) { + newIntervalMinute = (long) Math + .floor( + IntervalTimeConfiguration.getIntervalInMinute(detectorInterval) * INTERVAL_RECOMMENDATION_DECREASING_MULTIPLIER + ); + } else { + newIntervalMinute = (long) Math + .ceil( + IntervalTimeConfiguration.getIntervalInMinute(detectorInterval) * INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER + ); + } + double fullBucketRate = processBucketAggregationResults(aggregate); + // If rate is above success minimum then return interval suggestion. + if (fullBucketRate > INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE) { + intervalListener.onResponse(this.detectorInterval); + } else if (expirationEpochMs < clock.millis()) { + listener + .onFailure( + new ValidationException( + ADCommonMessages.TIMEOUT_ON_INTERVAL_REC, + ValidationIssueType.TIMEOUT, + ValidationAspect.MODEL + ) + ); + logger.info(ADCommonMessages.TIMEOUT_ON_INTERVAL_REC); + // keep trying higher intervals as new interval is below max, and we aren't decreasing yet + } else if (newIntervalMinute < MAX_INTERVAL_REC_LENGTH_IN_MINUTES && !decreasingInterval) { + searchWithDifferentInterval(newIntervalMinute); + // The below block is executed only the first time when new interval is above max and + // we aren't decreasing yet, at this point we will start decreasing for the first time + // if we are inside the below block + } else if (newIntervalMinute >= MAX_INTERVAL_REC_LENGTH_IN_MINUTES && !decreasingInterval) { + IntervalTimeConfiguration givenInterval = (IntervalTimeConfiguration) anomalyDetector.getInterval(); + this.detectorInterval = new IntervalTimeConfiguration( + (long) Math + .floor( + IntervalTimeConfiguration.getIntervalInMinute(givenInterval) * INTERVAL_RECOMMENDATION_DECREASING_MULTIPLIER + ), + ChronoUnit.MINUTES + ); + if (detectorInterval.getInterval() <= 0) { + intervalListener.onResponse(null); + return; + } + this.decreasingInterval = true; + this.numTimesDecreasing -= 1; + // Searching again using an updated interval + SearchSourceBuilder updatedSearchSourceBuilder = getSearchSourceBuilder( + searchSourceBuilder.query(), + getBucketAggregation(this.latestTime, new IntervalTimeConfiguration(newIntervalMinute, ChronoUnit.MINUTES)) + ); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + 
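// -------------------------------------------------------------------------------------------
// Illustrative aside, not part of this change: a simplified standalone sketch of the order in
// which candidate intervals are tried, assuming the AnomalyDetectorSettings defaults
// MAX_INTERVAL_REC_LENGTH_IN_MINUTES = 60, INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER = 1.2
// and INTERVAL_RECOMMENDATION_DECREASING_MULTIPLIER = 0.8. The real listener stops as soon as a
// candidate reaches the 0.75 fill rate, or when the recommendation times out.
//
//     import java.util.ArrayList;
//     import java.util.List;
//
//     static List<Long> candidateIntervalsMinutes(long givenMinutes, int maxDecreases) {
//         List<Long> candidates = new ArrayList<>();
//         candidates.add(givenMinutes);            // the configured interval is tested first
//         long interval = givenMinutes;
//         // grow by 1.2x while the next candidate stays under the 60 minute cap
//         while ((long) Math.ceil(interval * 1.2) < 60) {
//             interval = (long) Math.ceil(interval * 1.2);
//             candidates.add(interval);
//         }
//         // then shrink the originally configured interval by 0.8x, a bounded number of times
//         interval = (long) Math.floor(givenMinutes * 0.8);
//         for (int i = 0; i <= maxDecreases && interval > 0; i++) {
//             candidates.add(interval);
//             interval = (long) Math.floor(interval * 0.8);
//         }
//         return candidates;
//     }
//
// For a 10 minute detector interval this yields 10, 12, 15, 18, 22, 27, ... and then 8, 6, 4, ...
// -------------------------------------------------------------------------------------------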
clientUtil + .asyncRequestWithInjectedSecurity( + new SearchRequest() + .indices(anomalyDetector.getIndices().toArray(new String[0])) + .source(updatedSearchSourceBuilder), + client::search, + user, + client, + this + ); + // In this case decreasingInterval has to be true already, so we will stop + // when the next new interval is below or equal to 0, or we have decreased up to max times + } else if (numTimesDecreasing >= 0 && newIntervalMinute > 0) { + this.numTimesDecreasing -= 1; + searchWithDifferentInterval(newIntervalMinute); + // this case means all intervals up to max interval recommendation length and down to either + // 0 or until we tried 10 lower intervals than the one given have been tried + // which further means the next step is to go through A/B validation checks + } else { + intervalListener.onResponse(null); + } + + } catch (Exception e) { + onFailure(e); + } + } + + private void searchWithDifferentInterval(long newIntervalMinuteValue) { + this.detectorInterval = new IntervalTimeConfiguration(newIntervalMinuteValue, ChronoUnit.MINUTES); + // Searching again using an updated interval + SearchSourceBuilder updatedSearchSourceBuilder = getSearchSourceBuilder( + searchSourceBuilder.query(), + getBucketAggregation(this.latestTime, new IntervalTimeConfiguration(newIntervalMinuteValue, ChronoUnit.MINUTES)) + ); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + new SearchRequest().indices(anomalyDetector.getIndices().toArray(new String[0])).source(updatedSearchSourceBuilder), + client::search, + user, + client, + this + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to recommend new interval", e); + listener + .onFailure( + new ValidationException( + ADCommonMessages.MODEL_VALIDATION_FAILED_UNEXPECTEDLY, + ValidationIssueType.AGGREGATION, + ValidationAspect.MODEL + ) + ); + } + } + + private void processIntervalRecommendation(IntervalTimeConfiguration interval, long latestTime) { + // if interval suggestion is null that means no interval could be found with all the configurations + // applied, our next step then is to check density just with the raw data and then add each configuration + // one at a time to try and find root cause of low density + if (interval == null) { + checkRawDataSparsity(latestTime); + } else { + if (interval.equals(anomalyDetector.getInterval())) { + logger.info("Using the current interval there is enough dense data "); + // Check if there is a window delay recommendation if everything else is successful and send exception + if (Instant.now().toEpochMilli() - latestTime > timeConfigToMilliSec(anomalyDetector.getWindowDelay())) { + sendWindowDelayRec(latestTime); + return; + } + // The rate of buckets with at least 1 doc with given interval is above the success rate + listener.onResponse(null); + return; + } + // return response with interval recommendation + listener + .onFailure( + new ValidationException( + ADCommonMessages.DETECTOR_INTERVAL_REC + interval.getInterval(), + ValidationIssueType.DETECTION_INTERVAL, + ValidationAspect.MODEL, + interval + ) + ); + } + } + + private AggregationBuilder getBucketAggregation(long latestTime, IntervalTimeConfiguration detectorInterval) { + return AggregationBuilders + .dateHistogram(AGGREGATION) + .field(anomalyDetector.getTimeField()) + .minDocCount(1) + .hardBounds(getTimeRangeBounds(latestTime, detectorInterval)) + 
.fixedInterval(DateHistogramInterval.minutes((int) IntervalTimeConfiguration.getIntervalInMinute(detectorInterval))); + } + + private SearchSourceBuilder getSearchSourceBuilder(QueryBuilder query, AggregationBuilder aggregation) { + return new SearchSourceBuilder().query(query).aggregation(aggregation).size(0).timeout(requestTimeout); + } + + private void checkRawDataSparsity(long latestTime) { + AggregationBuilder aggregation = getBucketAggregation(latestTime, (IntervalTimeConfiguration) anomalyDetector.getInterval()); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().aggregation(aggregation).size(0).timeout(requestTimeout); + SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener + .wrap(response -> processRawDataResults(response, latestTime), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + searchResponseListener + ); + } + + private Histogram checkBucketResultErrors(SearchResponse response) { + Aggregations aggs = response.getAggregations(); + if (aggs == null) { + // This would indicate some bug or some opensearch core changes that we are not aware of (we don't keep up-to-date with + // the large amounts of changes there). For this reason I'm not throwing a SearchException but instead a validation exception + // which will be converted to validation response. + logger.warn("Unexpected null aggregation."); + listener + .onFailure( + new ValidationException( + ADCommonMessages.MODEL_VALIDATION_FAILED_UNEXPECTEDLY, + ValidationIssueType.AGGREGATION, + ValidationAspect.MODEL + ) + ); + return null; + } + Histogram aggregate = aggs.get(AGGREGATION); + if (aggregate == null) { + listener.onFailure(new IllegalArgumentException("Failed to find valid aggregation result")); + return null; + } + return aggregate; + } + + private void processRawDataResults(SearchResponse response, long latestTime) { + Histogram aggregate = checkBucketResultErrors(response); + if (aggregate == null) { + return; + } + double fullBucketRate = processBucketAggregationResults(aggregate); + if (fullBucketRate < INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE) { + listener + .onFailure( + new ValidationException(ADCommonMessages.RAW_DATA_TOO_SPARSE, ValidationIssueType.INDICES, ValidationAspect.MODEL) + ); + } else { + checkDataFilterSparsity(latestTime); + } + } + + private void checkDataFilterSparsity(long latestTime) { + AggregationBuilder aggregation = getBucketAggregation(latestTime, (IntervalTimeConfiguration) anomalyDetector.getInterval()); + BoolQueryBuilder query = QueryBuilders.boolQuery().filter(anomalyDetector.getFilterQuery()); + SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder(query, aggregation); + SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener + .wrap(response -> processDataFilterResults(response, latestTime), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + searchResponseListener + ); 
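// ---------------------------------------------------------------------------------------
// Illustrative aside, not part of this change: the sparsity checks issue a search shaped
// roughly like the sketch below (aggregation and field names are placeholders). The hard
// bounds cover getNumberOfSamples() intervals ending at the latest data time; for a
// 10 minute interval, getNumberOfSamples() = max(24h / 10m, MIN_TRAIN_SAMPLES) =
// max(144, 512) = 512, so the histogram looks back 512 * 10 minutes, roughly 3.6 days.
// The raw-data check sends only the aggregation; the filter/category/feature checks add a
// bool query carrying the corresponding filter, term, or exists clauses.
//
//     {
//       "size": 0,
//       "query": { "bool": { "filter": [ "...filter / category / feature clauses..." ] } },
//       "aggregations": {
//         "agg": {
//           "date_histogram": {
//             "field": "timestamp",
//             "fixed_interval": "10m",
//             "min_doc_count": 1,
//             "hard_bounds": { "min": "<latestTime - 512 * 10m>", "max": "<latestTime>" }
//           }
//         }
//       }
//     }
// ---------------------------------------------------------------------------------------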
+ } + + private void processDataFilterResults(SearchResponse response, long latestTime) { + Histogram aggregate = checkBucketResultErrors(response); + if (aggregate == null) { + return; + } + double fullBucketRate = processBucketAggregationResults(aggregate); + if (fullBucketRate < CONFIG_BUCKET_MINIMUM_SUCCESS_RATE) { + listener + .onFailure( + new ValidationException( + ADCommonMessages.FILTER_QUERY_TOO_SPARSE, + ValidationIssueType.FILTER_QUERY, + ValidationAspect.MODEL + ) + ); + // blocks below are executed if data is dense enough with filter query applied. + // If HCAD then category fields will be added to bucket aggregation to see if they + // are the root cause of the issues and if not the feature queries will be checked for sparsity + } else if (anomalyDetector.isHighCardinality()) { + getTopEntityForCategoryField(latestTime); + } else { + try { + checkFeatureQueryDelegate(latestTime); + } catch (Exception ex) { + logger.error(ex); + listener.onFailure(ex); + } + } + } + + private void getTopEntityForCategoryField(long latestTime) { + ActionListener> getTopEntityListener = ActionListener + .wrap(topEntity -> checkCategoryFieldSparsity(topEntity, latestTime), exception -> { + listener.onFailure(exception); + logger.error("Failed to get top entity for categorical field", exception); + return; + }); + getTopEntity(getTopEntityListener); + } + + private void checkCategoryFieldSparsity(Map topEntity, long latestTime) { + BoolQueryBuilder query = QueryBuilders.boolQuery().filter(anomalyDetector.getFilterQuery()); + for (Map.Entry entry : topEntity.entrySet()) { + query.filter(QueryBuilders.termQuery(entry.getKey(), entry.getValue())); + } + AggregationBuilder aggregation = getBucketAggregation(latestTime, (IntervalTimeConfiguration) anomalyDetector.getInterval()); + SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder(query, aggregation); + SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener + .wrap(response -> processTopEntityResults(response, latestTime), listener::onFailure); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + searchResponseListener + ); + } + + private void processTopEntityResults(SearchResponse response, long latestTime) { + Histogram aggregate = checkBucketResultErrors(response); + if (aggregate == null) { + return; + } + double fullBucketRate = processBucketAggregationResults(aggregate); + if (fullBucketRate < CONFIG_BUCKET_MINIMUM_SUCCESS_RATE) { + listener + .onFailure( + new ValidationException( + ADCommonMessages.CATEGORY_FIELD_TOO_SPARSE, + ValidationIssueType.CATEGORY, + ValidationAspect.MODEL + ) + ); + } else { + try { + checkFeatureQueryDelegate(latestTime); + } catch (Exception ex) { + logger.error(ex); + listener.onFailure(ex); + } + } + } + + private void checkFeatureQueryDelegate(long latestTime) throws IOException { + ActionListener> validateFeatureQueriesListener = ActionListener + .wrap( + response -> { windowDelayRecommendation(latestTime); }, + exception -> { + listener + .onFailure( + new ValidationException(exception.getMessage(), ValidationIssueType.FEATURE_ATTRIBUTES, ValidationAspect.MODEL) + ); + } + ); + MultiResponsesDelegateActionListener> multiFeatureQueriesResponseListener = + new 
MultiResponsesDelegateActionListener<>( + validateFeatureQueriesListener, + anomalyDetector.getFeatureAttributes().size(), + ADCommonMessages.FEATURE_QUERY_TOO_SPARSE, + false + ); + + for (Feature feature : anomalyDetector.getFeatureAttributes()) { + AggregationBuilder aggregation = getBucketAggregation(latestTime, (IntervalTimeConfiguration) anomalyDetector.getInterval()); + BoolQueryBuilder query = QueryBuilders.boolQuery().filter(anomalyDetector.getFilterQuery()); + List featureFields = ParseUtils.getFieldNamesForFeature(feature, xContentRegistry); + for (String featureField : featureFields) { + query.filter(QueryBuilders.existsQuery(featureField)); + } + SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder(query, aggregation); + SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])) + .source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener.wrap(response -> { + Histogram aggregate = checkBucketResultErrors(response); + if (aggregate == null) { + return; + } + double fullBucketRate = processBucketAggregationResults(aggregate); + if (fullBucketRate < CONFIG_BUCKET_MINIMUM_SUCCESS_RATE) { + multiFeatureQueriesResponseListener + .onFailure( + new ValidationException( + ADCommonMessages.FEATURE_QUERY_TOO_SPARSE, + ValidationIssueType.FEATURE_ATTRIBUTES, + ValidationAspect.MODEL + ) + ); + } else { + multiFeatureQueriesResponseListener + .onResponse(new MergeableList<>(new ArrayList<>(Collections.singletonList(new double[] { fullBucketRate })))); + } + }, e -> { + logger.error(e); + multiFeatureQueriesResponseListener + .onFailure(new OpenSearchStatusException(ADCommonMessages.FEATURE_QUERY_TOO_SPARSE, RestStatus.BAD_REQUEST, e)); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + user, + client, + searchResponseListener + ); + } + } + + private void sendWindowDelayRec(long latestTimeInMillis) { + long minutesSinceLastStamp = (long) Math.ceil((Instant.now().toEpochMilli() - latestTimeInMillis) / 60000.0); + listener + .onFailure( + new ValidationException( + String.format(Locale.ROOT, ADCommonMessages.WINDOW_DELAY_REC, minutesSinceLastStamp, minutesSinceLastStamp), + ValidationIssueType.WINDOW_DELAY, + ValidationAspect.MODEL, + new IntervalTimeConfiguration(minutesSinceLastStamp, ChronoUnit.MINUTES) + ) + ); + } + + private void windowDelayRecommendation(long latestTime) { + // Check if there is a better window-delay to recommend and if one was recommended + // then send exception and return, otherwise continue to let user know data is too sparse as explained below + if (Instant.now().toEpochMilli() - latestTime > timeConfigToMilliSec(anomalyDetector.getWindowDelay())) { + sendWindowDelayRec(latestTime); + return; + } + // This case has been reached if following conditions are met: + // 1. no interval recommendation was found that leads to a bucket success rate of >= 0.75 + // 2. bucket success rate with the given interval and just raw data is also below 0.75. + // 3. 
no single configuration during the following checks reduced the bucket success rate below 0.25 + // This means the rate with all configs applied or just raw data was below 0.75 but the rate when checking each configuration at + // a time was always above 0.25 meaning the best suggestion is to simply ingest more data or change interval since + // we have no more insight regarding the root cause of the lower density. + listener + .onFailure(new ValidationException(ADCommonMessages.RAW_DATA_TOO_SPARSE, ValidationIssueType.INDICES, ValidationAspect.MODEL)); + } + + private LongBounds getTimeRangeBounds(long endMillis, IntervalTimeConfiguration detectorIntervalInMinutes) { + Long detectorInterval = timeConfigToMilliSec(detectorIntervalInMinutes); + Long startMillis = endMillis - (getNumberOfSamples() * detectorInterval); + return new LongBounds(startMillis, endMillis); + } + + private int getNumberOfSamples() { + long interval = anomalyDetector.getIntervalInMilliseconds(); + return Math + .max( + (int) (Duration.ofHours(AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS).toMillis() / interval), + AnomalyDetectorSettings.MIN_TRAIN_SAMPLES + ); + } + + private Long timeConfigToMilliSec(TimeConfiguration config) { + return Optional.ofNullable((IntervalTimeConfiguration) config).map(t -> t.toDuration().toMillis()).orElse(0L); + } +} diff --git a/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java-e b/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java-e new file mode 100644 index 000000000..163d1df63 --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java-e @@ -0,0 +1,115 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest.handler; + +import java.time.Clock; + +import org.opensearch.action.ActionListener; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.rest.RestRequest; + +/** + * Anomaly detector REST action handler to process POST request. + * POST request is for validating anomaly detector against detector and/or model configs. + */ +public class ValidateAnomalyDetectorActionHandler extends AbstractAnomalyDetectorActionHandler { + + /** + * Constructor function. 
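 *
 * (Illustrative note: this handler backs the validate REST API, e.g.
 * POST _plugins/_anomaly_detection/detectors/_validate/model, where a trailing "detector" or
 * "model" path element selects the validationType parameter below; "detector" limits validation
 * to the blocking configuration checks, while "model" additionally runs the non-blocking
 * ModelValidationActionHandler checks described earlier in this change.)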
+ * + * @param clusterService ClusterService + * @param client ES node client that executes actions on the local node + * @param clientUtil AD client utility + * @param listener ES channel used to construct bytes / builder based outputs, and send responses + * @param anomalyDetectionIndices anomaly detector index manager + * @param anomalyDetector anomaly detector instance + * @param requestTimeout request time out configuration + * @param maxSingleEntityAnomalyDetectors max single-entity anomaly detectors allowed + * @param maxMultiEntityAnomalyDetectors max multi-entity detectors allowed + * @param maxAnomalyFeatures max features allowed per detector + * @param method Rest Method type + * @param xContentRegistry Registry which is used for XContentParser + * @param user User context + * @param searchFeatureDao Search feature DAO + * @param validationType Specified type for validation + * @param clock Clock object to know when to timeout + * @param settings Node settings + */ + public ValidateAnomalyDetectorActionHandler( + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + ActionListener listener, + ADIndexManagement anomalyDetectionIndices, + AnomalyDetector anomalyDetector, + TimeValue requestTimeout, + Integer maxSingleEntityAnomalyDetectors, + Integer maxMultiEntityAnomalyDetectors, + Integer maxAnomalyFeatures, + RestRequest.Method method, + NamedXContentRegistry xContentRegistry, + User user, + SearchFeatureDao searchFeatureDao, + String validationType, + Clock clock, + Settings settings + ) { + super( + clusterService, + client, + clientUtil, + null, + listener, + anomalyDetectionIndices, + AnomalyDetector.NO_ID, + null, + null, + null, + anomalyDetector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + method, + xContentRegistry, + user, + null, + searchFeatureDao, + validationType, + true, + clock, + settings + ); + } + + // If validation type is detector then all validation in AbstractAnomalyDetectorActionHandler that is called + // by super.start() involves validation checks against the detector configurations, + // any issues raised here would block user from creating the anomaly detector. + // If validation Aspect is of type model then further non-blocker validation will be executed + // after the blocker validation is executed. Any issues that are raised for model validation + // are simply warnings for the user in terms of how configuration could be changed to lead to + // a higher likelihood of model training completing successfully + @Override + public void start() { + super.start(); + } +} diff --git a/src/main/java/org/opensearch/ad/settings/ADEnabledSetting.java-e b/src/main/java/org/opensearch/ad/settings/ADEnabledSetting.java-e new file mode 100644 index 000000000..ed4414f6c --- /dev/null +++ b/src/main/java/org/opensearch/ad/settings/ADEnabledSetting.java-e @@ -0,0 +1,131 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.settings; + +import static java.util.Collections.unmodifiableMap; +import static org.opensearch.common.settings.Setting.Property.Deprecated; +import static org.opensearch.common.settings.Setting.Property.Dynamic; +import static org.opensearch.common.settings.Setting.Property.NodeScope; + +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.common.settings.Setting; +import org.opensearch.timeseries.settings.DynamicNumericSetting; + +public class ADEnabledSetting extends DynamicNumericSetting { + + /** + * Singleton instance + */ + private static ADEnabledSetting INSTANCE; + + /** + * Settings name + */ + public static final String AD_ENABLED = "plugins.anomaly_detection.enabled"; + + public static final String AD_BREAKER_ENABLED = "plugins.anomaly_detection.breaker.enabled"; + + public static final String LEGACY_OPENDISTRO_AD_ENABLED = "opendistro.anomaly_detection.enabled"; + + public static final String LEGACY_OPENDISTRO_AD_BREAKER_ENABLED = "opendistro.anomaly_detection.breaker.enabled"; + + public static final String INTERPOLATION_IN_HCAD_COLD_START_ENABLED = "plugins.anomaly_detection.hcad_cold_start_interpolation.enabled"; + + public static final String DOOR_KEEPER_IN_CACHE_ENABLED = "plugins.anomaly_detection.door_keeper_in_cache.enabled"; + + public static final Map> settings = unmodifiableMap(new HashMap>() { + { + Setting LegacyADEnabledSetting = Setting.boolSetting(LEGACY_OPENDISTRO_AD_ENABLED, true, NodeScope, Dynamic, Deprecated); + /** + * Legacy OpenDistro AD enable/disable setting + */ + put(LEGACY_OPENDISTRO_AD_ENABLED, LegacyADEnabledSetting); + + Setting LegacyADBreakerEnabledSetting = Setting + .boolSetting(LEGACY_OPENDISTRO_AD_BREAKER_ENABLED, true, NodeScope, Dynamic, Deprecated); + /** + * Legacy OpenDistro AD breaker enable/disable setting + */ + put(LEGACY_OPENDISTRO_AD_BREAKER_ENABLED, LegacyADBreakerEnabledSetting); + + /** + * AD enable/disable setting + */ + put(AD_ENABLED, Setting.boolSetting(AD_ENABLED, LegacyADEnabledSetting, NodeScope, Dynamic)); + + /** + * AD breaker enable/disable setting + */ + put(AD_BREAKER_ENABLED, Setting.boolSetting(AD_BREAKER_ENABLED, LegacyADBreakerEnabledSetting, NodeScope, Dynamic)); + + /** + * Whether interpolation in HCAD cold start is enabled or not + */ + put( + INTERPOLATION_IN_HCAD_COLD_START_ENABLED, + Setting.boolSetting(INTERPOLATION_IN_HCAD_COLD_START_ENABLED, false, NodeScope, Dynamic) + ); + + /** + * We have a bloom filter placed in front of inactive entity cache to + * filter out unpopular items that are not likely to appear more + * than once. Whether this bloom filter is enabled or not. + */ + put(DOOR_KEEPER_IN_CACHE_ENABLED, Setting.boolSetting(DOOR_KEEPER_IN_CACHE_ENABLED, false, NodeScope, Dynamic)); + } + }); + + ADEnabledSetting(Map> settings) { + super(settings); + } + + public static synchronized ADEnabledSetting getInstance() { + if (INSTANCE == null) { + INSTANCE = new ADEnabledSetting(settings); + } + return INSTANCE; + } + + /** + * Whether AD is enabled. If disabled, time series plugin rejects RESTful requests on AD and stop all AD jobs. + * @return whether AD is enabled. + */ + public static boolean isADEnabled() { + return ADEnabledSetting.getInstance().getSettingValue(ADEnabledSetting.AD_ENABLED); + } + + /** + * Whether AD circuit breaker is enabled or not. If disabled, an open circuit breaker wouldn't cause an AD job to be stopped. + * @return whether AD circuit breaker is enabled or not. 
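 *
 * (Illustrative example: the flag is dynamic, so it can be flipped at runtime with
 * PUT _cluster/settings {"persistent": {"plugins.anomaly_detection.breaker.enabled": false}};
 * clusters that only ever set the legacy opendistro.anomaly_detection.breaker.enabled key
 * inherit that value through the fallback declared above.)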
+ */ + public static boolean isADBreakerEnabled() { + return ADEnabledSetting.getInstance().getSettingValue(ADEnabledSetting.AD_BREAKER_ENABLED); + } + + /** + * If enabled, we use samples plus interpolation to train models. + * @return wWhether interpolation in HCAD cold start is enabled or not. + */ + public static boolean isInterpolationInColdStartEnabled() { + return ADEnabledSetting.getInstance().getSettingValue(ADEnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED); + } + + /** + * If enabled, we filter out unpopular items that are not likely to appear more than once + * @return wWhether door keeper in cache is enabled or not. + */ + public static boolean isDoorKeeperInCacheEnabled() { + return ADEnabledSetting.getInstance().getSettingValue(ADEnabledSetting.DOOR_KEEPER_IN_CACHE_ENABLED); + } +} diff --git a/src/main/java/org/opensearch/ad/settings/ADNumericSetting.java-e b/src/main/java/org/opensearch/ad/settings/ADNumericSetting.java-e new file mode 100644 index 000000000..e064867a0 --- /dev/null +++ b/src/main/java/org/opensearch/ad/settings/ADNumericSetting.java-e @@ -0,0 +1,66 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.settings; + +import static java.util.Collections.unmodifiableMap; + +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.common.settings.Setting; +import org.opensearch.timeseries.settings.DynamicNumericSetting; + +public class ADNumericSetting extends DynamicNumericSetting { + + /** + * Singleton instance + */ + private static ADNumericSetting INSTANCE; + + /** + * Settings name + */ + public static final String CATEGORY_FIELD_LIMIT = "plugins.anomaly_detection.category_field_limit"; + + private static final Map> settings = unmodifiableMap(new HashMap>() { + { + // how many categorical fields we support + // The number of category field won't causes correctness issues for our + // implementation, but can cause performance issues. The more categorical + // fields, the larger of the anomaly results, intermediate states, and + // more expensive entities (e.g., to get top entities in preview API, we need + // to use scripts in terms aggregation. The more fields, the slower the query). 
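            // Illustrative example: with the default limit of 2, a detector can group results by,
            // say, a "region" and a "host" field; raising this dynamic setting toward its maximum
            // of 5 allows more categorical fields per detector, at the query and result-size cost
            // described above.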
+ put( + CATEGORY_FIELD_LIMIT, + Setting.intSetting(CATEGORY_FIELD_LIMIT, 2, 0, 5, Setting.Property.NodeScope, Setting.Property.Dynamic) + ); + } + }); + + ADNumericSetting(Map> settings) { + super(settings); + } + + public static synchronized ADNumericSetting getInstance() { + if (INSTANCE == null) { + INSTANCE = new ADNumericSetting(settings); + } + return INSTANCE; + } + + /** + * @return the max number of categorical fields + */ + public static int maxCategoricalFields() { + return ADNumericSetting.getInstance().getSettingValue(ADNumericSetting.CATEGORY_FIELD_LIMIT); + } +} diff --git a/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java-e b/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java-e new file mode 100644 index 000000000..22e72eba0 --- /dev/null +++ b/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java-e @@ -0,0 +1,828 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.settings; + +import java.time.Duration; + +import org.opensearch.common.settings.Setting; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +/** + * AD plugin settings. + */ +public final class AnomalyDetectorSettings { + + private AnomalyDetectorSettings() {} + + public static final int MAX_DETECTOR_UPPER_LIMIT = 10000; + public static final Setting MAX_SINGLE_ENTITY_ANOMALY_DETECTORS = Setting + .intSetting( + "plugins.anomaly_detection.max_anomaly_detectors", + LegacyOpenDistroAnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, + 0, + MAX_DETECTOR_UPPER_LIMIT, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting MAX_MULTI_ENTITY_ANOMALY_DETECTORS = Setting + .intSetting( + "plugins.anomaly_detection.max_multi_entity_anomaly_detectors", + LegacyOpenDistroAnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS, + 0, + MAX_DETECTOR_UPPER_LIMIT, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting MAX_ANOMALY_FEATURES = Setting + .intSetting( + "plugins.anomaly_detection.max_anomaly_features", + LegacyOpenDistroAnomalyDetectorSettings.MAX_ANOMALY_FEATURES, + 0, + 10, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting REQUEST_TIMEOUT = Setting + .positiveTimeSetting( + "plugins.anomaly_detection.request_timeout", + LegacyOpenDistroAnomalyDetectorSettings.REQUEST_TIMEOUT, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting DETECTION_INTERVAL = Setting + .positiveTimeSetting( + "plugins.anomaly_detection.detection_interval", + LegacyOpenDistroAnomalyDetectorSettings.DETECTION_INTERVAL, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting DETECTION_WINDOW_DELAY = Setting + .timeSetting( + "plugins.anomaly_detection.detection_window_delay", + LegacyOpenDistroAnomalyDetectorSettings.DETECTION_WINDOW_DELAY, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting AD_RESULT_HISTORY_ROLLOVER_PERIOD = Setting + .positiveTimeSetting( + "plugins.anomaly_detection.ad_result_history_rollover_period", + LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, + 
Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // Opensearch-only setting. Doesn't plan to use the value of the legacy setting + // AD_RESULT_HISTORY_MAX_DOCS as that's too low. If the clusterManager node uses opendistro code, + // it uses the legacy setting. If the clusterManager node uses opensearch code, it uses the new setting. + public static final Setting AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD = Setting + .longSetting( + "plugins.anomaly_detection.ad_result_history_max_docs_per_shard", + // Total documents in the primary shards. + // Note the count is for Lucene docs. Lucene considers a nested + // doc a doc too. One result corresponding to 4 Lucene docs. + // A single Lucene doc is roughly 46.8 bytes (measured by experiments). + // 1.35 billion docs is about 65 GB. One shard can have at most 65 GB. + // This number in Lucene doc count is used in RolloverRequest#addMaxIndexDocsCondition + // for adding condition to check if the index has at least numDocs. + 1_350_000_000L, + 0L, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting AD_RESULT_HISTORY_RETENTION_PERIOD = Setting + .positiveTimeSetting( + "plugins.anomaly_detection.ad_result_history_retention_period", + LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting MAX_RETRY_FOR_UNRESPONSIVE_NODE = Setting + .intSetting( + "plugins.anomaly_detection.max_retry_for_unresponsive_node", + LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE, + 0, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting COOLDOWN_MINUTES = Setting + .positiveTimeSetting( + "plugins.anomaly_detection.cooldown_minutes", + LegacyOpenDistroAnomalyDetectorSettings.COOLDOWN_MINUTES, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting BACKOFF_MINUTES = Setting + .positiveTimeSetting( + "plugins.anomaly_detection.backoff_minutes", + LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_MINUTES, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting AD_BACKOFF_INITIAL_DELAY = Setting + .positiveTimeSetting( + "plugins.anomaly_detection.backoff_initial_delay", + LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_INITIAL_DELAY, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting AD_MAX_RETRY_FOR_BACKOFF = Setting + .intSetting( + "plugins.anomaly_detection.max_retry_for_backoff", + LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_BACKOFF, + 0, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting MAX_RETRY_FOR_END_RUN_EXCEPTION = Setting + .intSetting( + "plugins.anomaly_detection.max_retry_for_end_run_exception", + LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_END_RUN_EXCEPTION, + 0, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FILTER_BY_BACKEND_ROLES = Setting + .boolSetting( + "plugins.anomaly_detection.filter_by_backend_roles", + LegacyOpenDistroAnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final String ANOMALY_RESULTS_INDEX_MAPPING_FILE = "mappings/anomaly-results.json"; + public static final String ANOMALY_DETECTION_STATE_INDEX_MAPPING_FILE = "mappings/anomaly-detection-state.json"; + public static final String CHECKPOINT_INDEX_MAPPING_FILE 
= "mappings/anomaly-checkpoint.json"; + + public static final Duration HOURLY_MAINTENANCE = Duration.ofHours(1); + + // saving checkpoint every 12 hours. + // To support 1 million entities in 36 data nodes, each node has roughly 28K models. + // In each hour, we roughly need to save 2400 models. Since each model saving can + // take about 1 seconds (default value of AnomalyDetectorSettings.EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_SECS) + // we can use up to 2400 seconds to finish saving checkpoints. + public static final Setting CHECKPOINT_SAVING_FREQ = Setting + .positiveTimeSetting( + "plugins.anomaly_detection.checkpoint_saving_freq", + TimeValue.timeValueHours(12), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting CHECKPOINT_TTL = Setting + .positiveTimeSetting( + "plugins.anomaly_detection.checkpoint_ttl", + TimeValue.timeValueDays(7), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // ====================================== + // ML parameters + // ====================================== + // RCF + public static final int NUM_SAMPLES_PER_TREE = 256; + + public static final int NUM_TREES = 30; + + public static final int TRAINING_SAMPLE_INTERVAL = 64; + + public static final double TIME_DECAY = 0.0001; + + // If we have 32 + shingleSize (hopefully recent) values, RCF can get up and running. It will be noisy — + // there is a reason that default size is 256 (+ shingle size), but it may be more useful for people to + /// start seeing some results. + public static final int NUM_MIN_SAMPLES = 32; + + // The threshold for splitting RCF models in single-stream detectors. + // The smallest machine in the Amazon managed service has 1GB heap. + // With the setting, the desired model size there is of 2 MB. + // By default, we can have at most 5 features. Since the default shingle size + // is 8, we have at most 40 dimensions in RCF. In our current RCF setting, + // 30 trees, and bounding box cache ratio 0, 40 dimensions use 449KB. + // Users can increase the number of features to 10 and shingle size to 60, + // 30 trees, bounding box cache ratio 0, 600 dimensions use 1.8 MB. + // Since these sizes are smaller than the threshold 2 MB, we won't split models + // even in the smallest machine. 
+ public static final double DESIRED_MODEL_SIZE_PERCENTAGE = 0.002; + + public static final Setting MODEL_MAX_SIZE_PERCENTAGE = Setting + .doubleSetting( + "plugins.anomaly_detection.model_max_size_percent", + LegacyOpenDistroAnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, + 0, + 0.7, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // for a batch operation, we want all of the bounding box in-place for speed + public static final double BATCH_BOUNDING_BOX_CACHE_RATIO = 1; + + // Thresholding + public static final double THRESHOLD_MIN_PVALUE = 0.995; + + public static final double THRESHOLD_MAX_RANK_ERROR = 0.0001; + + public static final double THRESHOLD_MAX_SCORE = 8; + + public static final int THRESHOLD_NUM_LOGNORMAL_QUANTILES = 400; + + public static final int THRESHOLD_DOWNSAMPLES = 5_000; + + public static final long THRESHOLD_MAX_SAMPLES = 50_000; + + // Feature processing + public static final int MAX_TRAIN_SAMPLE = 24; + + public static final int MAX_SAMPLE_STRIDE = 64; + + public static final int TRAIN_SAMPLE_TIME_RANGE_IN_HOURS = 24; + + public static final int MIN_TRAIN_SAMPLES = 512; + + public static final int MAX_IMPUTATION_NEIGHBOR_DISTANCE = 2; + + // shingling + public static final double MAX_SHINGLE_PROPORTION_MISSING = 0.25; + + // AD JOB + public static final long DEFAULT_AD_JOB_LOC_DURATION_SECONDS = 60; + + // Thread pool + public static final int AD_THEAD_POOL_QUEUE_SIZE = 1000; + + // multi-entity caching + public static final int MAX_ACTIVE_STATES = 1000; + + // the size of the cache for small states like last cold start time for an entity. + // At most, we have 10 multi-entity detector and each one can be hit by 1000 different entities each + // minute. Since these states' life time is hour, we keep its size 10 * 1000 = 10000. + public static final int MAX_SMALL_STATES = 10000; + + // ====================================== + // cache related parameters + // ====================================== + /* + * Opensearch-only setting + * Each detector has its dedicated cache that stores ten entities' states per node. + * A detector's hottest entities load their states into the dedicated cache. + * Other detectors cannot use space reserved by a detector's dedicated cache. + * DEDICATED_CACHE_SIZE is a setting to make dedicated cache's size flexible. + * When that setting is changed, if the size decreases, we will release memory + * if required (e.g., when a user also decreased AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, + * the max memory percentage that AD can use); + * if the size increases, we may reject the setting change if we cannot fulfill + * that request (e.g., when it will uses more memory than allowed for AD). + * + * With compact rcf, rcf with 30 trees and shingle size 4 is of 500KB. + * The recommended max heap size is 32 GB. Even if users use all of the heap + * for AD, the max number of entity model cannot surpass + * 3.2 GB/500KB = 3.2 * 10^10 / 5*10^5 = 6.4 * 10 ^4 + * where 3.2 GB is from 10% memory limit of AD plugin. + * That's why I am using 60_000 as the max limit. + */ + public static final Setting DEDICATED_CACHE_SIZE = Setting + .intSetting("plugins.anomaly_detection.dedicated_cache_size", 10, 0, 60_000, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // We only keep priority (4 bytes float) in inactive cache. 1 million priorities + // take up 4 MB. 
+ public static final int MAX_INACTIVE_ENTITIES = 1_000_000; + + // Increase the value will adding pressure to indexing anomaly results and our feature query + // OpenSearch-only setting as previous the legacy default is too low (1000) + public static final Setting MAX_ENTITIES_PER_QUERY = Setting + .intSetting( + "plugins.anomaly_detection.max_entities_per_query", + 1_000_000, + 0, + 2_000_000, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // save partial zero-anomaly grade results after indexing pressure reaching the limit + // Opendistro version has similar setting. I lowered the value to make room + // for INDEX_PRESSURE_HARD_LIMIT. I don't find a floatSetting that has both default + // and fallback values. I want users to use the new default value 0.6 instead of 0.8. + // So do not plan to use the value of legacy setting as fallback. + public static final Setting AD_INDEX_PRESSURE_SOFT_LIMIT = Setting + .floatSetting( + "plugins.anomaly_detection.index_pressure_soft_limit", + 0.6f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // save only error or larger-than-one anomaly grade results after indexing + // pressure reaching the limit + // opensearch-only setting + public static final Setting AD_INDEX_PRESSURE_HARD_LIMIT = Setting + .floatSetting( + "plugins.anomaly_detection.index_pressure_hard_limit", + 0.9f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // max number of primary shards of an AD index + public static final Setting AD_MAX_PRIMARY_SHARDS = Setting + .intSetting( + "plugins.anomaly_detection.max_primary_shards", + LegacyOpenDistroAnomalyDetectorSettings.MAX_PRIMARY_SHARDS, + 0, + 200, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // max entity value's length + public static int MAX_ENTITY_LENGTH = 256; + + // number of bulk checkpoints per second + public static double CHECKPOINT_BULK_PER_SECOND = 0.02; + + // ====================================== + // Historical analysis + // ====================================== + // Maximum number of batch tasks running on one node. + // TODO: performance test and tune the setting. + public static final Setting MAX_BATCH_TASK_PER_NODE = Setting + .intSetting( + "plugins.anomaly_detection.max_batch_task_per_node", + LegacyOpenDistroAnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE, + 1, + 100, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // Maximum number of deleted tasks can keep in cache. + public static final Setting MAX_CACHED_DELETED_TASKS = Setting + .intSetting( + "plugins.anomaly_detection.max_cached_deleted_tasks", + 1000, + 1, + 10_000, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // Maximum number of old AD tasks we can keep. + public static int MAX_OLD_AD_TASK_DOCS = 1000; + public static final Setting MAX_OLD_AD_TASK_DOCS_PER_DETECTOR = Setting + .intSetting( + "plugins.anomaly_detection.max_old_ad_task_docs_per_detector", + // One AD task is roughly 1.5KB for normal case. Suppose task's size + // is 2KB conservatively. If we store 1000 AD tasks for one detector, + // that will be 2GB. 
+ LegacyOpenDistroAnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, + 1, // keep at least 1 old AD task per detector + MAX_OLD_AD_TASK_DOCS, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting BATCH_TASK_PIECE_SIZE = Setting + .intSetting( + "plugins.anomaly_detection.batch_task_piece_size", + LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE, + 1, + TimeSeriesSettings.MAX_BATCH_TASK_PIECE_SIZE, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting BATCH_TASK_PIECE_INTERVAL_SECONDS = Setting + .intSetting( + "plugins.anomaly_detection.batch_task_piece_interval_seconds", + LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS, + 1, + 600, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // Maximum number of entities we support for historical analysis. + public static final int MAX_TOP_ENTITIES_LIMIT_FOR_HISTORICAL_ANALYSIS = 10_000; + public static final Setting MAX_TOP_ENTITIES_FOR_HISTORICAL_ANALYSIS = Setting + .intSetting( + "plugins.anomaly_detection.max_top_entities_for_historical_analysis", + 1000, + 1, + MAX_TOP_ENTITIES_LIMIT_FOR_HISTORICAL_ANALYSIS, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS = Setting + .intSetting( + "plugins.anomaly_detection.max_running_entities_per_detector_for_historical_analysis", + 10, + 1, + 1000, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // ====================================== + // rate-limiting queue parameters + // ====================================== + // the percentage of heap usage allowed for queues holding small requests + // set it to 0 to disable the queue + public static final Setting COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.anomaly_detection.cold_entity_queue_max_heap_percent", + 0.001f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.anomaly_detection.checkpoint_read_queue_max_heap_percent", + 0.001f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.anomaly_detection.entity_cold_start_queue_max_heap_percent", + 0.001f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // the percentage of heap usage allowed for queues holding large requests + // set it to 0 to disable the queue + public static final Setting CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.anomaly_detection.checkpoint_write_queue_max_heap_percent", + 0.01f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.anomaly_detection.result_write_queue_max_heap_percent", + 0.01f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.anomaly_detection.checkpoint_maintain_queue_max_heap_percent", + 0.001f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // expected execution time per cold entity request. This setting controls + // the speed of cold entity requests execution. 
The larger, the faster, and + // the more performance impact to customers' workload. + public static final Setting EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS = Setting + .intSetting( + "plugins.anomaly_detection.expected_cold_entity_execution_time_in_millisecs", + 3000, + 0, + 3600000, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // expected execution time per checkpoint maintain request. This setting controls + // the speed of checkpoint maintenance execution. The larger, the faster, and + // the more performance impact to customers' workload. + public static final Setting AD_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS = Setting + .intSetting( + "plugins.anomaly_detection.expected_checkpoint_maintain_time_in_millisecs", + 1000, + 0, + 3600000, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * EntityRequest has entityName (# category fields * 256, the recommended limit + * of a keyword field length), model Id (roughly 256 bytes), and QueuedRequest + * fields including detector Id(roughly 128 bytes), expirationEpochMs (long, + * 8 bytes), and priority (12 bytes). + * Plus Java object size (12 bytes), we have roughly 928 bytes per request + * assuming we have 2 categorical fields (plan to support 2 categorical fields now). + * We don't want the total size exceeds 0.1% of the heap. + * We can have at most 0.1% heap / 928 = heap / 928,000. + * For t3.small, 0.1% heap is of 1MB. The queue's size is up to + * 10^ 6 / 928 = 1078 + */ + public static int ENTITY_REQUEST_SIZE_IN_BYTES = 928; + + /** + * EntityFeatureRequest consists of EntityRequest (928 bytes, read comments + * of ENTITY_COLD_START_QUEUE_SIZE_CONSTANT), pointer to current feature + * (8 bytes), and dataStartTimeMillis (8 bytes). We have roughly + * 928 + 16 = 944 bytes per request. + * + * We don't want the total size exceeds 0.1% of the heap. + * We should have at most 0.1% heap / 944 = heap / 944,000 + * For t3.small, 0.1% heap is of 1MB. The queue's size is up to + * 10^ 6 / 944 = 1059 + */ + public static int ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES = 944; + + /** + * ResultWriteRequest consists of index request (roughly 1KB), and QueuedRequest + * fields (148 bytes, read comments of ENTITY_REQUEST_SIZE_CONSTANT). + * Plus Java object size (12 bytes), we have roughly 1160 bytes per request + * + * We don't want the total size exceeds 1% of the heap. + * We should have at most 1% heap / 1148 = heap / 116,000 + * For t3.small, 1% heap is of 10MB. The queue's size is up to + * 10^ 7 / 1160 = 8621 + */ + public static int RESULT_WRITE_QUEUE_SIZE_IN_BYTES = 1160; + + /** + * CheckpointWriteRequest consists of IndexRequest (200 KB), and QueuedRequest + * fields (148 bytes, read comments of ENTITY_REQUEST_SIZE_CONSTANT). + * The total is roughly 200 KB per request. + * + * We don't want the total size exceeds 1% of the heap. + * We should have at most 1% heap / 200KB = heap / 20,000,000 + * For t3.small, 1% heap is of 10MB. The queue's size is up to + * 10^ 7 / 2.0 * 10^5 = 50 + */ + public static int CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES = 200_000; + + /** + * CheckpointMaintainRequest has model Id (roughly 256 bytes), and QueuedRequest + * fields including detector Id(roughly 128 bytes), expirationEpochMs (long, + * 8 bytes), and priority (12 bytes). + * Plus Java object size (12 bytes), we have roughly 416 bytes per request. + * We don't want the total size exceeds 0.1% of the heap. + * We can have at most 0.1% heap / 416 = heap / 416,000. + * For t3.small, 0.1% heap is of 1MB. 
The queue's size is up to + * 10^ 6 / 416 = 2403 + */ + public static int CHECKPOINT_MAINTAIN_REQUEST_SIZE_IN_BYTES = 416; + + /** + * Max concurrent entity cold starts per node + */ + public static final Setting ENTITY_COLD_START_QUEUE_CONCURRENCY = Setting + .intSetting( + "plugins.anomaly_detection.entity_cold_start_queue_concurrency", + 1, + 1, + 10, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * Max concurrent checkpoint reads per node + */ + public static final Setting AD_CHECKPOINT_READ_QUEUE_CONCURRENCY = Setting + .intSetting( + "plugins.anomaly_detection.checkpoint_read_queue_concurrency", + 1, + 1, + 10, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * Max concurrent checkpoint writes per node + */ + public static final Setting AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY = Setting + .intSetting( + "plugins.anomaly_detection.checkpoint_write_queue_concurrency", + 2, + 1, + 10, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * Max concurrent result writes per node. Since checkpoint is relatively large + * (250KB), we have 2 concurrent threads processing the queue. + */ + public static final Setting AD_RESULT_WRITE_QUEUE_CONCURRENCY = Setting + .intSetting( + "plugins.anomaly_detection.result_write_queue_concurrency", + 2, + 1, + 10, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * Assume each checkpoint takes roughly 200KB. 25 requests are of 5 MB. + */ + public static final Setting AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE = Setting + .intSetting( + "plugins.anomaly_detection.checkpoint_read_queue_batch_size", + 25, + 1, + 60, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * ES recommends bulk size to be 5~15 MB. + * ref: https://tinyurl.com/3zdbmbwy + * Assume each checkpoint takes roughly 200KB. 25 requests are of 5 MB. + */ + public static final Setting AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE = Setting + .intSetting( + "plugins.anomaly_detection.checkpoint_write_queue_batch_size", + 25, + 1, + 60, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * ES recommends bulk size to be 5~15 MB. + * ref: https://tinyurl.com/3zdbmbwy + * Assume each result takes roughly 1KB. 5000 requests are of 5 MB. + */ + public static final Setting AD_RESULT_WRITE_QUEUE_BATCH_SIZE = Setting + .intSetting( + "plugins.anomaly_detection.result_write_queue_batch_size", + 5000, + 1, + 15000, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Duration QUEUE_MAINTENANCE = Duration.ofMinutes(10); + + public static final float MAX_QUEUED_TASKS_RATIO = 0.5f; + + public static final float MEDIUM_SEGMENT_PRUNE_RATIO = 0.1f; + + public static final float LOW_SEGMENT_PRUNE_RATIO = 0.3f; + + // expensive maintenance (e.g., queue maintenance) with 1/10000 probability + public static final int MAINTENANCE_FREQ_CONSTANT = 10000; + + // ====================================== + // Checkpoint setting + // ====================================== + // we won't accept a checkpoint larger than 30MB. Or we risk OOM. + // For reference, in RCF 1.0, the checkpoint of a RCF with 50 trees, 10 dimensions, + // 256 samples is of 3.2MB. + // In compact rcf, the same RCF is of 163KB. + // Since we allow at most 5 features, and the default shingle size is 8 and default + // tree number size is 100, we can have at most 25.6 MB in RCF 1.0. + // It is possible that cx increases the max features or shingle size, but we don't want + // to risk OOM for the flexibility. 
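    // Worked numbers (illustrative): checkpoint size grows roughly linearly with trees and
    // dimensions, so starting from the measured 3.2 MB (50 trees, 10 dimensions), the default
    // 100 trees with 5 features * 8 shingle = 40 dimensions gives about
    // 3.2 MB * (100 / 50) * (40 / 10) = 25.6 MB, which still fits under the 30 MB cap below.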
+ public static final int MAX_CHECKPOINT_BYTES = 30_000_000; + + // Sets the cap on the number of buffer that can be allocated by the rcf deserialization + // buffer pool. Each buffer is of 512 bytes. Memory occupied by 20 buffers is 10.24 KB. + public static final int MAX_TOTAL_RCF_SERIALIZATION_BUFFERS = 20; + + // the size of the buffer used for rcf deserialization + public static final int SERIALIZATION_BUFFER_BYTES = 512; + + // ====================================== + // pagination setting + // ====================================== + // pagination size + public static final Setting PAGE_SIZE = Setting + .intSetting("plugins.anomaly_detection.page_size", 1_000, 0, 10_000, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // within an interval, how many percents are used to process requests. + // 1.0 means we use all of the detection interval to process requests. + // to ensure we don't block next interval, it is better to set it less than 1.0. + public static final float INTERVAL_RATIO_FOR_REQUESTS = 0.9f; + + // ====================================== + // preview setting + // ====================================== + public static final int MIN_PREVIEW_SIZE = 400; // ok to lower + + public static final double PREVIEW_SAMPLE_RATE = 0.25; // ok to adjust, higher for more data, lower for lower latency + + public static final int MAX_PREVIEW_SAMPLES = 300; // ok to adjust, higher for more data, lower for lower latency + + public static final int MAX_PREVIEW_RESULTS = 1_000; // ok to adjust, higher for more data, lower for lower latency + + // Maximum number of entities retrieved for Preview API + // Not using legacy value 30 as default. + // Setting default value to 30 of 2-categorical field detector causes heavy GC + // (half of the time is GC on my 1GB heap machine). This is because we run concurrent + // feature aggregations/training/prediction. + // Default value 5 won't cause heavy GC on an 1-GB heap JVM. + // Since every entity is likely to give some anomalies, 5 entities are enough. + public static final Setting MAX_ENTITIES_FOR_PREVIEW = Setting + .intSetting("plugins.anomaly_detection.max_entities_for_preview", 5, 1, 30, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // max concurrent preview to limit resource usage + public static final Setting MAX_CONCURRENT_PREVIEW = Setting + .intSetting("plugins.anomaly_detection.max_concurrent_preview", 2, 1, 20, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // preview timeout in terms of milliseconds + public static final long PREVIEW_TIMEOUT_IN_MILLIS = 60_000; + + // ====================================== + // top anomaly result API setting + // ====================================== + public static final long TOP_ANOMALY_RESULT_TIMEOUT_IN_MILLIS = 60_000; + + // ====================================== + // cleanup resouce setting + // ====================================== + public static final Setting DELETE_AD_RESULT_WHEN_DELETE_DETECTOR = Setting + .boolSetting( + "plugins.anomaly_detection.delete_anomaly_result_when_delete_detector", + false, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // ====================================== + // stats/profile API setting + // ====================================== + // the max number of models to return per node. 
+ // the setting is used to limit resource usage due to showing models + public static final Setting MAX_MODEL_SIZE_PER_NODE = Setting + .intSetting( + "plugins.anomaly_detection.max_model_size_per_node", + 100, + 1, + 10_000, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // profile API needs to report total entities. We can use cardinality aggregation for a single-category field. + // But we cannot do that for multi-category fields as it requires scripting to generate run time fields, + // which is expensive. We work around the problem by using a composite query to find the first 10_000 buckets. + // Generally, traversing all buckets/combinations can't be done without visiting all matches, which is costly + // for data with many entities. Given that it is often enough to have a lower bound of the number of entities, + // such as "there are at least 10000 entities", the default is set to 10,000. That is, requests will count the + // total entities up to 10,000. + public static final int MAX_TOTAL_ENTITIES_TO_TRACK = 10_000; + + // ====================================== + // Cold start setting + // ====================================== + public static int MAX_COLD_START_ROUNDS = 2; + + // ====================================== + // Validate Detector API setting + // ====================================== + public static final long TOP_VALIDATE_TIMEOUT_IN_MILLIS = 10_000; + public static final long MAX_INTERVAL_REC_LENGTH_IN_MINUTES = 60L; + public static final double INTERVAL_RECOMMENDATION_INCREASING_MULTIPLIER = 1.2; + public static final double INTERVAL_RECOMMENDATION_DECREASING_MULTIPLIER = 0.8; + public static final double INTERVAL_BUCKET_MINIMUM_SUCCESS_RATE = 0.75; + public static final double CONFIG_BUCKET_MINIMUM_SUCCESS_RATE = 0.25; + // This value is set to decrease the number of times we decrease the interval when recommending a new one + // The reason we need a max is because user could give an arbitrarly large interval where we don't know even + // with multiplying the interval down how many intervals will be tried. + public static final int MAX_TIMES_DECREASING_INTERVAL = 10; +} diff --git a/src/main/java/org/opensearch/ad/settings/LegacyOpenDistroAnomalyDetectorSettings.java-e b/src/main/java/org/opensearch/ad/settings/LegacyOpenDistroAnomalyDetectorSettings.java-e new file mode 100644 index 000000000..d8ca4b777 --- /dev/null +++ b/src/main/java/org/opensearch/ad/settings/LegacyOpenDistroAnomalyDetectorSettings.java-e @@ -0,0 +1,312 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.settings; + +import org.opensearch.common.settings.Setting; +import org.opensearch.common.unit.TimeValue; + +/** + * Legacy OpenDistro AD plugin settings. 
+ */ +public class LegacyOpenDistroAnomalyDetectorSettings { + + private LegacyOpenDistroAnomalyDetectorSettings() {} + + public static final Setting MAX_SINGLE_ENTITY_ANOMALY_DETECTORS = Setting + .intSetting( + "opendistro.anomaly_detection.max_anomaly_detectors", + 1000, + 0, + 10_000, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting MAX_MULTI_ENTITY_ANOMALY_DETECTORS = Setting + .intSetting( + "opendistro.anomaly_detection.max_multi_entity_anomaly_detectors", + 10, + 0, + 10_000, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting MAX_ANOMALY_FEATURES = Setting + .intSetting( + "opendistro.anomaly_detection.max_anomaly_features", + 5, + 0, + 100, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting REQUEST_TIMEOUT = Setting + .positiveTimeSetting( + "opendistro.anomaly_detection.request_timeout", + TimeValue.timeValueSeconds(10), + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting DETECTION_INTERVAL = Setting + .positiveTimeSetting( + "opendistro.anomaly_detection.detection_interval", + TimeValue.timeValueMinutes(10), + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting DETECTION_WINDOW_DELAY = Setting + .timeSetting( + "opendistro.anomaly_detection.detection_window_delay", + TimeValue.timeValueMinutes(0), + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting AD_RESULT_HISTORY_ROLLOVER_PERIOD = Setting + .positiveTimeSetting( + "opendistro.anomaly_detection.ad_result_history_rollover_period", + TimeValue.timeValueHours(12), + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting AD_RESULT_HISTORY_MAX_DOCS = Setting + .longSetting( + "opendistro.anomaly_detection.ad_result_history_max_docs", + // Total documents in primary replica. + // A single feature result is roughly 150 bytes. Suppose a doc is + // of 200 bytes, 250 million docs is of 50 GB. We choose 50 GB + // because we have 1 shard at least. One shard can have at most 50 GB. 
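                // Worked out with the assumed sizes above (estimates, not measurements):
                //   250,000,000 docs * 200 bytes/doc ≈ 50 GB, i.e. one full 50 GB shard;
                //   at the ~150-byte single-feature estimate the same doc count is ≈ 37.5 GB,
                //   so 200 bytes/doc keeps the default on the conservative side.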
+ 250_000_000L, + 0L, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting AD_RESULT_HISTORY_RETENTION_PERIOD = Setting + .positiveTimeSetting( + "opendistro.anomaly_detection.ad_result_history_retention_period", + TimeValue.timeValueDays(30), + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting MAX_RETRY_FOR_UNRESPONSIVE_NODE = Setting + .intSetting( + "opendistro.anomaly_detection.max_retry_for_unresponsive_node", + 5, + 0, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting COOLDOWN_MINUTES = Setting + .positiveTimeSetting( + "opendistro.anomaly_detection.cooldown_minutes", + TimeValue.timeValueMinutes(5), + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting BACKOFF_MINUTES = Setting + .positiveTimeSetting( + "opendistro.anomaly_detection.backoff_minutes", + TimeValue.timeValueMinutes(15), + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting BACKOFF_INITIAL_DELAY = Setting + .positiveTimeSetting( + "opendistro.anomaly_detection.backoff_initial_delay", + TimeValue.timeValueMillis(1000), + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting MAX_RETRY_FOR_BACKOFF = Setting + .intSetting( + "opendistro.anomaly_detection.max_retry_for_backoff", + 3, + 0, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting MAX_RETRY_FOR_END_RUN_EXCEPTION = Setting + .intSetting( + "opendistro.anomaly_detection.max_retry_for_end_run_exception", + 6, + 0, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting FILTER_BY_BACKEND_ROLES = Setting + .boolSetting( + "opendistro.anomaly_detection.filter_by_backend_roles", + false, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + // ====================================== + // ML parameters + // ====================================== + // RCF + public static final Setting MODEL_MAX_SIZE_PERCENTAGE = Setting + .doubleSetting( + "opendistro.anomaly_detection.model_max_size_percent", + 0.1, + 0, + 0.7, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + // Increase the value will adding pressure to indexing anomaly results and our feature query + public static final Setting MAX_ENTITIES_PER_QUERY = Setting + .intSetting( + "opendistro.anomaly_detection.max_entities_per_query", + 1000, + 1, + 100_000_000, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + // Default number of entities retrieved for Preview API + public static final int DEFAULT_ENTITIES_FOR_PREVIEW = 30; + + // Maximum number of entities retrieved for Preview API + public static final Setting MAX_ENTITIES_FOR_PREVIEW = Setting + .intSetting( + "opendistro.anomaly_detection.max_entities_for_preview", + DEFAULT_ENTITIES_FOR_PREVIEW, + 1, + 1000, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + // save partial zero-anomaly grade results after indexing pressure reaching the limit + public static final Setting INDEX_PRESSURE_SOFT_LIMIT = Setting + .floatSetting( + 
"opendistro.anomaly_detection.index_pressure_soft_limit", + 0.8f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + // max number of primary shards of an AD index + public static final Setting MAX_PRIMARY_SHARDS = Setting + .intSetting( + "opendistro.anomaly_detection.max_primary_shards", + 10, + 0, + 200, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + // responding to 100 cache misses per second allowed. + // 100 because the get threadpool (the one we need to get checkpoint) queue szie is 1000 + // and we may have 10 concurrent multi-entity detectors. So each detector can use: 1000 / 10 = 100 + // for 1m interval. if the max entity number is 3000 per node, it will need around 30m to get all of them cached + // Thus, for 5m internval, it will need 2.5 hours to cache all of them. for 1hour interval, it will be 30hours. + // but for 1 day interval, it will be 30 days. + public static Setting MAX_CACHE_MISS_HANDLING_PER_SECOND = Setting + .intSetting( + "opendistro.anomaly_detection.max_cache_miss_handling_per_second", + 100, + 0, + 1000, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + // Maximum number of batch tasks running on one node. + // TODO: performance test and tune the setting. + public static final Setting MAX_BATCH_TASK_PER_NODE = Setting + .intSetting( + "opendistro.anomaly_detection.max_batch_task_per_node", + 10, + 1, + 100, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting MAX_OLD_AD_TASK_DOCS_PER_DETECTOR = Setting + .intSetting( + "opendistro.anomaly_detection.max_old_ad_task_docs_per_detector", + // One AD task is roughly 1.5KB for normal case. Suppose task's size + // is 2KB conservatively. If we store 1000 AD tasks for one detector, + // and have 1000 detectors, that will be 2GB. + 1, // keep 1 old task by default to avoid consuming too much resource + 1, // keep at least 1 old AD task per detector + 1000, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting BATCH_TASK_PIECE_SIZE = Setting + .intSetting( + "opendistro.anomaly_detection.batch_task_piece_size", + 1000, + 1, + 10_000, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); + + public static final Setting BATCH_TASK_PIECE_INTERVAL_SECONDS = Setting + .intSetting( + "opendistro.anomaly_detection.batch_task_piece_interval_seconds", + 5, + 1, + 600, + Setting.Property.NodeScope, + Setting.Property.Dynamic, + Setting.Property.Deprecated + ); +} diff --git a/src/main/java/org/opensearch/ad/stats/ADStat.java-e b/src/main/java/org/opensearch/ad/stats/ADStat.java-e new file mode 100644 index 000000000..531205907 --- /dev/null +++ b/src/main/java/org/opensearch/ad/stats/ADStat.java-e @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.stats; + +import java.util.function.Supplier; + +import org.opensearch.ad.stats.suppliers.CounterSupplier; +import org.opensearch.ad.stats.suppliers.SettableSupplier; + +/** + * Class represents a stat the plugin keeps track of + */ +public class ADStat { + private Boolean clusterLevel; + private Supplier supplier; + + /** + * Constructor + * + * @param clusterLevel whether the stat has clusterLevel scope or nodeLevel scope + * @param supplier supplier that returns the stat's value + */ + public ADStat(Boolean clusterLevel, Supplier supplier) { + this.clusterLevel = clusterLevel; + this.supplier = supplier; + } + + /** + * Determines whether the stat is cluster specific or node specific + * + * @return true is stat is cluster level; false otherwise + */ + public Boolean isClusterLevel() { + return clusterLevel; + } + + /** + * Get the value of the statistic + * + * @return T value of the stat + */ + public T getValue() { + return supplier.get(); + } + + /** + * Set the value of the statistic + * + * @param value set value + */ + public void setValue(Long value) { + if (supplier instanceof SettableSupplier) { + ((SettableSupplier) supplier).set(value); + } + } + + /** + * Increments the supplier if it can be incremented + */ + public void increment() { + if (supplier instanceof CounterSupplier) { + ((CounterSupplier) supplier).increment(); + } + } + + /** + * Decrease the supplier if it can be decreased. + */ + public void decrement() { + if (supplier instanceof CounterSupplier) { + ((CounterSupplier) supplier).decrement(); + } + } +} diff --git a/src/main/java/org/opensearch/ad/stats/ADStats.java-e b/src/main/java/org/opensearch/ad/stats/ADStats.java-e new file mode 100644 index 000000000..1fb0e8fe4 --- /dev/null +++ b/src/main/java/org/opensearch/ad/stats/ADStats.java-e @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.stats; + +import java.util.HashMap; +import java.util.Map; + +/** + * This class is the main entry-point for access to the stats that the AD plugin keeps track of. 
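 *
 * <p>A minimal usage sketch (the stat name below is illustrative, not one of the
 * plugin's registered stats):
 * <pre>{@code
 * Map<String, ADStat<?>> stats = new HashMap<>();
 * stats.put("example_request_count", new ADStat<>(false, new CounterSupplier()));
 * ADStats adStats = new ADStats(stats);
 *
 * adStats.getStat("example_request_count").increment(); // delegates to CounterSupplier
 * Map<String, ADStat<?>> nodeStats = adStats.getNodeStats(); // stats with clusterLevel == false
 * }</pre>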
+ */ +public class ADStats { + + private Map> stats; + + /** + * Constructor + * + * @param stats Map of the stats that are to be kept + */ + public ADStats(Map> stats) { + this.stats = stats; + } + + /** + * Get the stats + * + * @return all of the stats + */ + public Map> getStats() { + return stats; + } + + /** + * Get individual stat by stat name + * + * @param key Name of stat + * @return ADStat + * @throws IllegalArgumentException thrown on illegal statName + */ + public ADStat getStat(String key) throws IllegalArgumentException { + if (!stats.keySet().contains(key)) { + throw new IllegalArgumentException("Stat=\"" + key + "\" does not exist"); + } + return stats.get(key); + } + + /** + * Get a map of the stats that are kept at the node level + * + * @return Map of stats kept at the node level + */ + public Map> getNodeStats() { + return getClusterOrNodeStats(false); + } + + /** + * Get a map of the stats that are kept at the cluster level + * + * @return Map of stats kept at the cluster level + */ + public Map> getClusterStats() { + return getClusterOrNodeStats(true); + } + + private Map> getClusterOrNodeStats(Boolean getClusterStats) { + Map> statsMap = new HashMap<>(); + + for (Map.Entry> entry : stats.entrySet()) { + if (entry.getValue().isClusterLevel() == getClusterStats) { + statsMap.put(entry.getKey(), entry.getValue()); + } + } + return statsMap; + } +} diff --git a/src/main/java/org/opensearch/ad/stats/ADStatsResponse.java b/src/main/java/org/opensearch/ad/stats/ADStatsResponse.java index fe3e82a48..f90e451f9 100644 --- a/src/main/java/org/opensearch/ad/stats/ADStatsResponse.java +++ b/src/main/java/org/opensearch/ad/stats/ADStatsResponse.java @@ -19,8 +19,8 @@ import org.apache.commons.lang.builder.ToStringBuilder; import org.opensearch.ad.model.Mergeable; import org.opensearch.ad.transport.ADStatsNodesResponse; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/stats/ADStatsResponse.java-e b/src/main/java/org/opensearch/ad/stats/ADStatsResponse.java-e new file mode 100644 index 000000000..f90e451f9 --- /dev/null +++ b/src/main/java/org/opensearch/ad/stats/ADStatsResponse.java-e @@ -0,0 +1,149 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.stats; + +import java.io.IOException; +import java.util.Map; + +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.ad.model.Mergeable; +import org.opensearch.ad.transport.ADStatsNodesResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +/** + * ADStatsResponse contains logic to merge the node stats and cluster stats together and return them to user + */ +public class ADStatsResponse implements ToXContentObject, Mergeable { + private ADStatsNodesResponse adStatsNodesResponse; + private Map clusterStats; + + /** + * Get cluster stats + * + * @return Map of cluster stats + */ + public Map getClusterStats() { + return clusterStats; + } + + /** + * Set cluster stats + * + * @param clusterStats Map of cluster stats + */ + public void setClusterStats(Map clusterStats) { + this.clusterStats = clusterStats; + } + + /** + * Get cluster stats + * + * @return ADStatsNodesResponse + */ + public ADStatsNodesResponse getADStatsNodesResponse() { + return adStatsNodesResponse; + } + + /** + * Sets adStatsNodesResponse + * + * @param adStatsNodesResponse AD Stats Response from Nodes + */ + public void setADStatsNodesResponse(ADStatsNodesResponse adStatsNodesResponse) { + this.adStatsNodesResponse = adStatsNodesResponse; + } + + /** + * Convert ADStatsResponse to XContent + * + * @param builder XContentBuilder + * @return XContentBuilder + * @throws IOException thrown on invalid input + */ + public XContentBuilder toXContent(XContentBuilder builder) throws IOException { + return toXContent(builder, ToXContent.EMPTY_PARAMS); + } + + public ADStatsResponse() {} + + public ADStatsResponse(StreamInput in) throws IOException { + adStatsNodesResponse = new ADStatsNodesResponse(in); + clusterStats = in.readMap(); + } + + public void writeTo(StreamOutput out) throws IOException { + adStatsNodesResponse.writeTo(out); + out.writeMap(clusterStats); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + for (Map.Entry clusterStat : clusterStats.entrySet()) { + builder.field(clusterStat.getKey(), clusterStat.getValue()); + } + adStatsNodesResponse.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); + return xContentBuilder.endObject(); + } + + @Override + public void merge(Mergeable other) { + if (this == other || other == null || getClass() != other.getClass()) { + return; + } + + ADStatsResponse otherResponse = (ADStatsResponse) other; + + if (otherResponse.adStatsNodesResponse != null) { + this.adStatsNodesResponse = otherResponse.adStatsNodesResponse; + } + + if (otherResponse.clusterStats != null) { + this.clusterStats = otherResponse.clusterStats; + } + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + + ADStatsResponse other = (ADStatsResponse) obj; + return new EqualsBuilder() + .append(adStatsNodesResponse, other.adStatsNodesResponse) + .append(clusterStats, other.clusterStats) + .isEquals(); + } + + @Override + public int hashCode() { + return new 
HashCodeBuilder().append(adStatsNodesResponse).append(clusterStats).toHashCode(); + } + + @Override + public String toString() { + return new ToStringBuilder(this) + .append("adStatsNodesResponse", adStatsNodesResponse) + .append("clusterStats", clusterStats) + .toString(); + } +} diff --git a/src/main/java/org/opensearch/ad/stats/InternalStatNames.java-e b/src/main/java/org/opensearch/ad/stats/InternalStatNames.java-e new file mode 100644 index 000000000..56ff012a5 --- /dev/null +++ b/src/main/java/org/opensearch/ad/stats/InternalStatNames.java-e @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.stats; + +/** + * Enum containing names of all internal stats which will not be returned + * in AD stats REST API. + */ +public enum InternalStatNames { + JVM_HEAP_USAGE("jvm_heap_usage"), + AD_USED_BATCH_TASK_SLOT_COUNT("ad_used_batch_task_slot_count"), + AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT("ad_detector_assigned_batch_task_slot_count"); + + private String name; + + InternalStatNames(String name) { + this.name = name; + } + + /** + * Get internal stat name + * + * @return name + */ + public String getName() { + return name; + } +} diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/CounterSupplier.java-e b/src/main/java/org/opensearch/ad/stats/suppliers/CounterSupplier.java-e new file mode 100644 index 000000000..39acd94ff --- /dev/null +++ b/src/main/java/org/opensearch/ad/stats/suppliers/CounterSupplier.java-e @@ -0,0 +1,48 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.stats.suppliers; + +import java.util.concurrent.atomic.LongAdder; +import java.util.function.Supplier; + +/** + * CounterSupplier provides a stateful count as the value + */ +public class CounterSupplier implements Supplier { + private LongAdder counter; + + /** + * Constructor + */ + public CounterSupplier() { + this.counter = new LongAdder(); + } + + @Override + public Long get() { + return counter.longValue(); + } + + /** + * Increments the value of the counter by 1 + */ + public void increment() { + counter.increment(); + } + + /** + * Decrease the value of the counter by 1 + */ + public void decrement() { + counter.decrement(); + } +} diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/IndexStatusSupplier.java-e b/src/main/java/org/opensearch/ad/stats/suppliers/IndexStatusSupplier.java-e new file mode 100644 index 000000000..ab9177cb5 --- /dev/null +++ b/src/main/java/org/opensearch/ad/stats/suppliers/IndexStatusSupplier.java-e @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.stats.suppliers; + +import java.util.function.Supplier; + +import org.opensearch.ad.util.IndexUtils; + +/** + * IndexStatusSupplier provides the status of an index as the value + */ +public class IndexStatusSupplier implements Supplier { + private IndexUtils indexUtils; + private String indexName; + + public static final String UNABLE_TO_RETRIEVE_HEALTH_MESSAGE = "unable to retrieve health"; + + /** + * Constructor + * + * @param indexUtils Utility for getting information about indices + * @param indexName Name of index to extract stats from + */ + public IndexStatusSupplier(IndexUtils indexUtils, String indexName) { + this.indexUtils = indexUtils; + this.indexName = indexName; + } + + @Override + public String get() { + try { + return indexUtils.getIndexHealthStatus(indexName); + } catch (IllegalArgumentException e) { + return UNABLE_TO_RETRIEVE_HEALTH_MESSAGE; + } + + } +} diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeCountSupplier.java-e b/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeCountSupplier.java-e new file mode 100644 index 000000000..8fdac74d7 --- /dev/null +++ b/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeCountSupplier.java-e @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.stats.suppliers; + +import java.util.function.Supplier; +import java.util.stream.Stream; + +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.ml.ModelManager; + +/** + * ModelsOnNodeCountSupplier provides the number of models a node contains + */ +public class ModelsOnNodeCountSupplier implements Supplier { + private ModelManager modelManager; + private CacheProvider cache; + + /** + * Constructor + * + * @param modelManager object that manages the model partitions hosted on the node + * @param cache object that manages multi-entity detectors' models + */ + public ModelsOnNodeCountSupplier(ModelManager modelManager, CacheProvider cache) { + this.modelManager = modelManager; + this.cache = cache; + } + + @Override + public Long get() { + return Stream.concat(modelManager.getAllModels().stream(), cache.get().getAllModels().stream()).count(); + } +} diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplier.java-e b/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplier.java-e new file mode 100644 index 000000000..3f5421032 --- /dev/null +++ b/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplier.java-e @@ -0,0 +1,95 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.stats.suppliers; + +import static org.opensearch.ad.ml.ModelState.LAST_CHECKPOINT_TIME_KEY; +import static org.opensearch.ad.ml.ModelState.LAST_USED_TIME_KEY; +import static org.opensearch.ad.ml.ModelState.MODEL_TYPE_KEY; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_MODEL_SIZE_PER_NODE; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.timeseries.constant.CommonName; + +/** + * ModelsOnNodeSupplier provides a List of ModelStates info for the models the nodes contains + */ +public class ModelsOnNodeSupplier implements Supplier>> { + private ModelManager modelManager; + private CacheProvider cache; + // the max number of models to return per node. Defaults to 100. + private volatile int numModelsToReturn; + + /** + * Set that contains the model stats that should be exposed. + */ + public static Set MODEL_STATE_STAT_KEYS = new HashSet<>( + Arrays + .asList( + CommonName.MODEL_ID_FIELD, + ADCommonName.DETECTOR_ID_KEY, + MODEL_TYPE_KEY, + CommonName.ENTITY_KEY, + LAST_USED_TIME_KEY, + LAST_CHECKPOINT_TIME_KEY + ) + ); + + /** + * Constructor + * + * @param modelManager object that manages the model partitions hosted on the node + * @param cache object that manages multi-entity detectors' models + * @param settings node settings accessor + * @param clusterService Cluster service accessor + */ + public ModelsOnNodeSupplier(ModelManager modelManager, CacheProvider cache, Settings settings, ClusterService clusterService) { + this.modelManager = modelManager; + this.cache = cache; + this.numModelsToReturn = MAX_MODEL_SIZE_PER_NODE.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_MODEL_SIZE_PER_NODE, it -> this.numModelsToReturn = it); + } + + @Override + public List> get() { + List> values = new ArrayList<>(); + Stream + .concat(modelManager.getAllModels().stream(), cache.get().getAllModels().stream()) + .limit(numModelsToReturn) + .forEach( + modelState -> values + .add( + modelState + .getModelStateAsMap() + .entrySet() + .stream() + .filter(entry -> MODEL_STATE_STAT_KEYS.contains(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) + ) + ); + + return values; + } +} diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/SettableSupplier.java-e b/src/main/java/org/opensearch/ad/stats/suppliers/SettableSupplier.java-e new file mode 100644 index 000000000..b39ecdde5 --- /dev/null +++ b/src/main/java/org/opensearch/ad/stats/suppliers/SettableSupplier.java-e @@ -0,0 +1,43 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.stats.suppliers; + +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; + +/** + * SettableSupplier allows a user to set the value of the supplier to be returned + */ +public class SettableSupplier implements Supplier { + protected AtomicLong value; + + /** + * Constructor + */ + public SettableSupplier() { + this.value = new AtomicLong(0L); + } + + @Override + public Long get() { + return value.get(); + } + + /** + * Set value to be returned by get + * + * @param value to set + */ + public void set(Long value) { + this.value.set(value); + } +} diff --git a/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java-e b/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java-e new file mode 100644 index 000000000..00c574669 --- /dev/null +++ b/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java-e @@ -0,0 +1,145 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.task; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.NUM_MIN_SAMPLES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.NUM_TREES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.TIME_DECAY; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.timeseries.model.Entity; + +import com.amazon.randomcutforest.config.Precision; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * AD batch task cache which will mainly hold these for one task: + * 1. RCF + * 2. threshold model + * 3. shingle + * 4. training data + * 5. 
entity if task is for HC detector + */ +public class ADBatchTaskCache { + private final String detectorId; + private final String taskId; + private final String detectorTaskId; + private ThresholdedRandomCutForest rcfModel; + private boolean thresholdModelTrained; + private Deque>> shingle; + private AtomicInteger thresholdModelTrainingDataSize = new AtomicInteger(0); + private AtomicBoolean cancelled = new AtomicBoolean(false); + private AtomicLong cacheMemorySize = new AtomicLong(0); + private String cancelReason; + private String cancelledBy; + private Entity entity; + + protected ADBatchTaskCache(ADTask adTask) { + this.detectorId = adTask.getId(); + this.taskId = adTask.getTaskId(); + this.detectorTaskId = adTask.getDetectorLevelTaskId(); + this.entity = adTask.getEntity(); + + AnomalyDetector detector = adTask.getDetector(); + int numberOfTrees = NUM_TREES; + int shingleSize = detector.getShingleSize(); + this.shingle = new ArrayDeque<>(shingleSize); + int dimensions = detector.getShingleSize() * detector.getEnabledFeatureIds().size(); + + rcfModel = ThresholdedRandomCutForest + .builder() + .dimensions(dimensions) + .numberOfTrees(numberOfTrees) + .timeDecay(TIME_DECAY) + .sampleSize(NUM_SAMPLES_PER_TREE) + .outputAfter(NUM_MIN_SAMPLES) + .initialAcceptFraction(NUM_MIN_SAMPLES * 1.0d / NUM_SAMPLES_PER_TREE) + .parallelExecutionEnabled(false) + .compact(true) + .precision(Precision.FLOAT_32) + .boundingBoxCacheFraction(AnomalyDetectorSettings.BATCH_BOUNDING_BOX_CACHE_RATIO) + .shingleSize(shingleSize) + .anomalyRate(1 - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE) + .build(); + + this.thresholdModelTrained = false; + } + + protected String getId() { + return detectorId; + } + + protected String getTaskId() { + return taskId; + } + + protected String getDetectorTaskId() { + return detectorTaskId; + } + + protected ThresholdedRandomCutForest getTRcfModel() { + return rcfModel; + } + + protected Deque>> getShingle() { + return shingle; + } + + protected void setThresholdModelTrained(boolean thresholdModelTrained) { + this.thresholdModelTrained = thresholdModelTrained; + } + + protected boolean isThresholdModelTrained() { + return thresholdModelTrained; + } + + public AtomicInteger getThresholdModelTrainingDataSize() { + return thresholdModelTrainingDataSize; + } + + protected AtomicLong getCacheMemorySize() { + return cacheMemorySize; + } + + protected boolean isCancelled() { + return cancelled.get(); + } + + protected String getCancelReason() { + return cancelReason; + } + + protected String getCancelledBy() { + return cancelledBy; + } + + public Entity getEntity() { + return entity; + } + + protected void cancel(String reason, String userName) { + this.cancelled.compareAndSet(false, true); + this.cancelReason = reason; + this.cancelledBy = userName; + } +} diff --git a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java index a30e3f14d..2140ecf10 100644 --- a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java +++ b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java @@ -11,7 +11,6 @@ package org.opensearch.ad.task; -import static org.opensearch.ad.AnomalyDetectorPlugin.AD_BATCH_TASK_THREAD_POOL_NAME; import static org.opensearch.ad.breaker.MemoryCircuitBreaker.DEFAULT_JVM_HEAP_USAGE_THRESHOLD; import static org.opensearch.ad.constant.ADCommonMessages.NO_ELIGIBLE_NODE_TO_RUN_DETECTOR; import static org.opensearch.ad.model.ADTask.CURRENT_PIECE_FIELD; @@ -28,6 +27,7 @@ import static 
org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_TOP_ENTITIES_LIMIT_FOR_HISTORICAL_ANALYSIS; import static org.opensearch.ad.settings.AnomalyDetectorSettings.NUM_MIN_SAMPLES; import static org.opensearch.ad.stats.InternalStatNames.JVM_HEAP_USAGE; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BATCH_TASK_THREAD_POOL_NAME; import static org.opensearch.timeseries.stats.StatNames.AD_EXECUTING_BATCH_TASK_COUNT; import static org.opensearch.timeseries.util.ParseUtils.isNullOrEmpty; diff --git a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java-e b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java-e new file mode 100644 index 000000000..2140ecf10 --- /dev/null +++ b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java-e @@ -0,0 +1,1392 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.task; + +import static org.opensearch.ad.breaker.MemoryCircuitBreaker.DEFAULT_JVM_HEAP_USAGE_THRESHOLD; +import static org.opensearch.ad.constant.ADCommonMessages.NO_ELIGIBLE_NODE_TO_RUN_DETECTOR; +import static org.opensearch.ad.model.ADTask.CURRENT_PIECE_FIELD; +import static org.opensearch.ad.model.ADTask.EXECUTION_END_TIME_FIELD; +import static org.opensearch.ad.model.ADTask.INIT_PROGRESS_FIELD; +import static org.opensearch.ad.model.ADTask.STATE_FIELD; +import static org.opensearch.ad.model.ADTask.TASK_PROGRESS_FIELD; +import static org.opensearch.ad.model.ADTask.WORKER_NODE_FIELD; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_TOP_ENTITIES_FOR_HISTORICAL_ANALYSIS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_TOP_ENTITIES_LIMIT_FOR_HISTORICAL_ANALYSIS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.NUM_MIN_SAMPLES; +import static org.opensearch.ad.stats.InternalStatNames.JVM_HEAP_USAGE; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BATCH_TASK_THREAD_POOL_NAME; +import static org.opensearch.timeseries.stats.StatNames.AD_EXECUTING_BATCH_TASK_COUNT; +import static org.opensearch.timeseries.util.ParseUtils.isNullOrEmpty; + +import java.time.Clock; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Deque; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.function.BiConsumer; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionListenerResponseHandler; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ThreadedActionListener; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.caching.PriorityTracker; 
+import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.feature.SinglePointFeatures; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskState; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.transport.ADBatchAnomalyResultRequest; +import org.opensearch.ad.transport.ADBatchAnomalyResultResponse; +import org.opensearch.ad.transport.ADBatchTaskRemoteExecutionAction; +import org.opensearch.ad.transport.ADStatsNodeResponse; +import org.opensearch.ad.transport.ADStatsNodesAction; +import org.opensearch.ad.transport.ADStatsRequest; +import org.opensearch.ad.transport.handler.AnomalyResultBulkIndexHandler; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.CheckedRunnable; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.InjectSecurity; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.bucket.terms.StringTerms; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.opensearch.search.aggregations.metrics.InternalMax; +import org.opensearch.search.aggregations.metrics.InternalMin; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TaskCancelledException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.parkservices.AnomalyDescriptor; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import 
com.google.common.collect.ImmutableSet; + +public class ADBatchTaskRunner { + private final Logger logger = LogManager.getLogger(ADBatchTaskRunner.class); + + private Settings settings; + private final ThreadPool threadPool; + private final Client client; + private final SecurityClientUtil clientUtil; + private final ADStats adStats; + private final ClusterService clusterService; + private final FeatureManager featureManager; + private final ADCircuitBreakerService adCircuitBreakerService; + private final ADTaskManager adTaskManager; + private final AnomalyResultBulkIndexHandler anomalyResultBulkIndexHandler; + private final ADIndexManagement anomalyDetectionIndices; + private final SearchFeatureDao searchFeatureDao; + + private final ADTaskCacheManager adTaskCacheManager; + private final TransportRequestOptions option; + private final HashRing hashRing; + private final ModelManager modelManager; + + private volatile Integer maxAdBatchTaskPerNode; + private volatile Integer pieceSize; + private volatile Integer pieceIntervalSeconds; + private volatile Integer maxTopEntitiesPerHcDetector; + private volatile Integer maxRunningEntitiesPerDetector; + + private static final int MAX_TOP_ENTITY_SEARCH_BUCKETS = 1000; + private static final int SLEEP_TIME_FOR_NEXT_ENTITY_TASK_IN_MILLIS = 2000; + + public ADBatchTaskRunner( + Settings settings, + ThreadPool threadPool, + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + ADCircuitBreakerService adCircuitBreakerService, + FeatureManager featureManager, + ADTaskManager adTaskManager, + ADIndexManagement anomalyDetectionIndices, + ADStats adStats, + AnomalyResultBulkIndexHandler anomalyResultBulkIndexHandler, + ADTaskCacheManager adTaskCacheManager, + SearchFeatureDao searchFeatureDao, + HashRing hashRing, + ModelManager modelManager + ) { + this.settings = settings; + this.threadPool = threadPool; + this.clusterService = clusterService; + this.client = client; + this.clientUtil = clientUtil; + this.anomalyResultBulkIndexHandler = anomalyResultBulkIndexHandler; + this.adStats = adStats; + this.adCircuitBreakerService = adCircuitBreakerService; + this.adTaskManager = adTaskManager; + this.featureManager = featureManager; + this.anomalyDetectionIndices = anomalyDetectionIndices; + + this.option = TransportRequestOptions + .builder() + .withType(TransportRequestOptions.Type.REG) + .withTimeout(AnomalyDetectorSettings.REQUEST_TIMEOUT.get(settings)) + .build(); + + this.adTaskCacheManager = adTaskCacheManager; + this.searchFeatureDao = searchFeatureDao; + this.hashRing = hashRing; + this.modelManager = modelManager; + + this.maxAdBatchTaskPerNode = MAX_BATCH_TASK_PER_NODE.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BATCH_TASK_PER_NODE, it -> maxAdBatchTaskPerNode = it); + + this.pieceSize = BATCH_TASK_PIECE_SIZE.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(BATCH_TASK_PIECE_SIZE, it -> pieceSize = it); + + this.pieceIntervalSeconds = BATCH_TASK_PIECE_INTERVAL_SECONDS.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(BATCH_TASK_PIECE_INTERVAL_SECONDS, it -> pieceIntervalSeconds = it); + + this.maxTopEntitiesPerHcDetector = MAX_TOP_ENTITIES_FOR_HISTORICAL_ANALYSIS.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(MAX_TOP_ENTITIES_FOR_HISTORICAL_ANALYSIS, it -> maxTopEntitiesPerHcDetector = it); + + this.maxRunningEntitiesPerDetector = 
MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS, it -> maxRunningEntitiesPerDetector = it); + } + + /** + * Run AD task. + * 1. For HC detector, will get top entities first(initialize top entities). If top + * entities already initialized, will execute AD task directly. + * 2. For single entity detector, execute AD task directly. + * @param adTask single entity or HC detector task + * @param transportService transport service + * @param listener action listener + */ + public void run(ADTask adTask, TransportService transportService, ActionListener listener) { + boolean isHCDetector = adTask.getDetector().isHighCardinality(); + if (isHCDetector && !adTaskCacheManager.topEntityInited(adTask.getId())) { + // Initialize top entities for HC detector + threadPool.executor(AD_BATCH_TASK_THREAD_POOL_NAME).execute(() -> { + ActionListener hcDelegatedListener = getInternalHCDelegatedListener(adTask); + ActionListener topEntitiesListener = getTopEntitiesListener(adTask, transportService, hcDelegatedListener); + try { + getTopEntities(adTask, topEntitiesListener); + } catch (Exception e) { + topEntitiesListener.onFailure(e); + } + }); + listener.onResponse(new ADBatchAnomalyResultResponse(clusterService.localNode().getId(), false)); + } else { + // Execute AD task for single entity detector or HC detector which top entities initialized + forwardOrExecuteADTask(adTask, transportService, listener); + } + } + + private ActionListener getInternalHCDelegatedListener(ADTask adTask) { + return ActionListener + .wrap( + r -> logger.debug("[InternalHCDelegatedListener]: running task {} on nodeId {}", adTask.getTaskId(), r.getNodeId()), + e -> logger.error("[InternalHCDelegatedListener]: failed to run task", e) + ); + } + + /** + * Create internal action listener for HC task. The action listener will be used in + * {@link ADBatchTaskRunner#getTopEntities(ADTask, ActionListener)}. Will call + * listener's onResponse method when get top entities done. + * + * @param adTask AD task + * @param transportService transport service + * @param listener action listener + * @return action listener + */ + private ActionListener getTopEntitiesListener( + ADTask adTask, + TransportService transportService, + ActionListener listener + ) { + String taskId = adTask.getTaskId(); + String detectorId = adTask.getId(); + ActionListener actionListener = ActionListener.wrap(response -> { + adTaskCacheManager.setTopEntityInited(detectorId); + int totalEntities = adTaskCacheManager.getPendingEntityCount(detectorId); + logger.info("Total top entities: {} for detector {}, task {}", totalEntities, detectorId, taskId); + hashRing.getNodesWithSameLocalAdVersion(dataNodes -> { + int numberOfEligibleDataNodes = dataNodes.length; + // maxAdBatchTaskPerNode means how many task can run on per data node, which is hard limitation per node. + // maxRunningEntitiesPerDetector means how many entities can run per detector on whole cluster, which is + // soft limit to control how many entities to run in parallel per HC detector. 
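                // Illustrative example with assumed numbers (not values read from settings):
                // with 3 eligible data nodes, maxAdBatchTaskPerNode = 10 and
                // maxRunningEntitiesPerDetector = 10, a detector with 1000 top entities gets
                //   min(1000, min(3 * 10, 10)) = 10
                // entity tasks running in parallel; the rest wait in the pending-entities queue.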
+ int maxRunningEntitiesLimit = Math + .min(totalEntities, Math.min(numberOfEligibleDataNodes * maxAdBatchTaskPerNode, maxRunningEntitiesPerDetector)); + adTaskCacheManager.setDetectorTaskLaneLimit(detectorId, maxRunningEntitiesLimit); + // scale down HC detector task slots in case there is less top entities in detection date range + int maxRunningEntities = Math.min(maxRunningEntitiesLimit, adTaskCacheManager.getDetectorTaskSlots(detectorId)); + logger + .debug( + "Calculate task lane for detector {}: totalEntities: {}, numberOfEligibleDataNodes: {}, maxAdBatchTaskPerNode: {}, " + + "maxRunningEntitiesPerDetector: {}, maxRunningEntities: {}, detectorTaskSlots: {}", + detectorId, + totalEntities, + numberOfEligibleDataNodes, + maxAdBatchTaskPerNode, + maxRunningEntitiesPerDetector, + maxRunningEntities, + adTaskCacheManager.getDetectorTaskSlots(detectorId) + ); + forwardOrExecuteADTask(adTask, transportService, listener); + // As we have started one entity task, need to minus 1 for max allowed running entities. + adTaskCacheManager.setAllowedRunningEntities(detectorId, maxRunningEntities - 1); + }, listener); + }, e -> { + logger.debug("Failed to run task " + taskId, e); + if (adTask.getTaskType().equals(ADTaskType.HISTORICAL_HC_DETECTOR.name())) { + adTaskManager.entityTaskDone(adTask, e, transportService); + } + listener.onFailure(e); + }); + ThreadedActionListener threadedActionListener = new ThreadedActionListener<>( + logger, + threadPool, + AD_BATCH_TASK_THREAD_POOL_NAME, + actionListener, + false + ); + return threadedActionListener; + } + + /** + * Get top entities for HC detector. Will use similar logic of realtime detector, + * but split the whole historical detection date range into limited number of + * buckets (1000 buckets by default). Will get top entities for each bucket, then + * put each bucket's top entities into {@link PriorityTracker} to track top + * entities dynamically. Once all buckets done, we can get finalized top entities + * in {@link PriorityTracker}. 
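 * For example (illustrative numbers only): a 30-day detection range with a
 * 10-minute detector interval is split into buckets of width
 * max(30 days / 1000, 10 minutes) = max(43.2 minutes, 10 minutes), i.e. about 43 minutes,
 * so at most 1000 buckets are searched; for short ranges the bucket width falls
 * back to the detector interval.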
+ * + * @param adTask AD task + * @param internalHCListener internal HC listener + */ + public void getTopEntities(ADTask adTask, ActionListener internalHCListener) { + getDateRangeOfSourceData(adTask, (dataStartTime, dataEndTime) -> { + PriorityTracker priorityTracker = new PriorityTracker( + Clock.systemUTC(), + adTask.getDetector().getIntervalInSeconds(), + adTask.getDetectionDateRange().getStartTime().toEpochMilli(), + MAX_TOP_ENTITIES_LIMIT_FOR_HISTORICAL_ANALYSIS + ); + long detectorInterval = adTask.getDetector().getIntervalInMilliseconds(); + logger + .debug( + "start to search top entities at {}, data start time: {}, data end time: {}, interval: {}", + System.currentTimeMillis(), + dataStartTime, + dataEndTime, + detectorInterval + ); + if (adTask.getDetector().hasMultipleCategories()) { + searchTopEntitiesForMultiCategoryHC( + adTask, + priorityTracker, + dataEndTime, + Math.max((dataEndTime - dataStartTime) / MAX_TOP_ENTITY_SEARCH_BUCKETS, detectorInterval), + dataStartTime, + dataStartTime + detectorInterval, + internalHCListener + ); + } else { + searchTopEntitiesForSingleCategoryHC( + adTask, + priorityTracker, + dataEndTime, + Math.max((dataEndTime - dataStartTime) / MAX_TOP_ENTITY_SEARCH_BUCKETS, detectorInterval), + dataStartTime, + dataStartTime + detectorInterval, + internalHCListener + ); + } + }, internalHCListener); + } + + private void searchTopEntitiesForMultiCategoryHC( + ADTask adTask, + PriorityTracker priorityTracker, + long detectionEndTime, + long bucketInterval, + long dataStartTime, + long dataEndTime, + ActionListener internalHCListener + ) { + checkIfADTaskCancelledAndCleanupCache(adTask); + ActionListener> topEntitiesListener = ActionListener.wrap(topEntities -> { + topEntities + .forEach(entity -> priorityTracker.updatePriority(adTaskManager.convertEntityToString(entity, adTask.getDetector()))); + + if (dataEndTime < detectionEndTime) { + searchTopEntitiesForMultiCategoryHC( + adTask, + priorityTracker, + detectionEndTime, + bucketInterval, + dataEndTime, + dataEndTime + bucketInterval, + internalHCListener + ); + } else { + logger.debug("finish searching top entities at " + System.currentTimeMillis()); + List topNEntities = priorityTracker.getTopNEntities(maxTopEntitiesPerHcDetector); + if (topNEntities.size() == 0) { + logger.error("There is no entity found for detector " + adTask.getId()); + internalHCListener.onFailure(new ResourceNotFoundException(adTask.getId(), "No entity found")); + return; + } + adTaskCacheManager.addPendingEntities(adTask.getId(), topNEntities); + adTaskCacheManager.setTopEntityCount(adTask.getId(), topNEntities.size()); + internalHCListener.onResponse("Get top entities done"); + } + }, e -> { + logger.error("Failed to get top entities for detector " + adTask.getId(), e); + internalHCListener.onFailure(e); + }); + int minimumDocCount = Math.max((int) (bucketInterval / adTask.getDetector().getIntervalInMilliseconds()) / 2, 1); + searchFeatureDao + .getHighestCountEntities( + adTask.getDetector(), + dataStartTime, + dataEndTime, + MAX_TOP_ENTITIES_LIMIT_FOR_HISTORICAL_ANALYSIS, + minimumDocCount, + MAX_TOP_ENTITIES_LIMIT_FOR_HISTORICAL_ANALYSIS, + topEntitiesListener + ); + } + + private void searchTopEntitiesForSingleCategoryHC( + ADTask adTask, + PriorityTracker priorityTracker, + long detectionEndTime, + long interval, + long dataStartTime, + long dataEndTime, + ActionListener internalHCListener + ) { + checkIfADTaskCancelledAndCleanupCache(adTask); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + 
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + RangeQueryBuilder rangeQueryBuilder = new RangeQueryBuilder(adTask.getDetector().getTimeField()) + .gte(dataStartTime) + .lte(dataEndTime) + .format("epoch_millis"); + boolQueryBuilder.filter(rangeQueryBuilder); + boolQueryBuilder.filter(adTask.getDetector().getFilterQuery()); + sourceBuilder.query(boolQueryBuilder); + + String topEntitiesAgg = "topEntities"; + AggregationBuilder aggregation = new TermsAggregationBuilder(topEntitiesAgg) + .field(adTask.getDetector().getCategoryFields().get(0)) + .size(MAX_TOP_ENTITIES_LIMIT_FOR_HISTORICAL_ANALYSIS); + sourceBuilder.aggregation(aggregation).size(0); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(sourceBuilder); + searchRequest.indices(adTask.getDetector().getIndices().toArray(new String[0])); + final ActionListener searchResponseListener = ActionListener.wrap(r -> { + StringTerms stringTerms = r.getAggregations().get(topEntitiesAgg); + List buckets = stringTerms.getBuckets(); + List topEntities = new ArrayList<>(); + for (StringTerms.Bucket bucket : buckets) { + String key = bucket.getKeyAsString(); + topEntities.add(key); + } + + topEntities.forEach(e -> priorityTracker.updatePriority(e)); + if (dataEndTime < detectionEndTime) { + searchTopEntitiesForSingleCategoryHC( + adTask, + priorityTracker, + detectionEndTime, + interval, + dataEndTime, + dataEndTime + interval, + internalHCListener + ); + } else { + logger.debug("finish searching top entities at " + System.currentTimeMillis()); + List topNEntities = priorityTracker.getTopNEntities(maxTopEntitiesPerHcDetector); + if (topNEntities.size() == 0) { + logger.error("There is no entity found for detector " + adTask.getId()); + internalHCListener.onFailure(new ResourceNotFoundException(adTask.getId(), "No entity found")); + return; + } + adTaskCacheManager.addPendingEntities(adTask.getId(), topNEntities); + adTaskCacheManager.setTopEntityCount(adTask.getId(), topNEntities.size()); + internalHCListener.onResponse("Get top entities done"); + } + }, e -> { + logger.error("Failed to get top entities for detector " + adTask.getId(), e); + internalHCListener.onFailure(e); + }); + // using the original context in listener as user roles have no permissions for internal operations like fetching a + // checkpoint + clientUtil + .asyncRequestWithInjectedSecurity( + searchRequest, + client::search, + // user is the one who started historical detector. Read AnomalyDetectorJobTransportAction.doExecute. + adTask.getUser(), + client, + searchResponseListener + ); + } + + /** + * Forward AD task to work node. + * 1. For HC detector, return directly if no more pending entity. Otherwise check if + * there is AD task created for this entity. If yes, just forward the entity task + * to worker node; otherwise, create entity task first, then forward. + * 2. For single entity detector, set task as INIT state and forward task to worker + * node. 
+ * + * @param adTask AD task + * @param transportService transport service + * @param listener action listener + */ + public void forwardOrExecuteADTask( + ADTask adTask, + TransportService transportService, + ActionListener listener + ) { + try { + checkIfADTaskCancelledAndCleanupCache(adTask); + String detectorId = adTask.getId(); + AnomalyDetector detector = adTask.getDetector(); + boolean isHCDetector = detector.isHighCardinality(); + if (isHCDetector) { + String entityString = adTaskCacheManager.pollEntity(detectorId); + logger.debug("Start to run entity: {} of detector {}", entityString, detectorId); + if (entityString == null) { + listener.onResponse(new ADBatchAnomalyResultResponse(clusterService.localNode().getId(), false)); + return; + } + ActionListener wrappedListener = ActionListener.wrap(r -> logger.debug("Entity task created successfully"), e -> { + logger.error("Failed to start entity task for detector: {}, entity: {}", detectorId, entityString); + // If fail, move the entity into pending task queue + adTaskCacheManager.addPendingEntity(detectorId, entityString); + }); + // This is to handle retry case. To retry entity, we need to get the old entity task created before. + Entity entity = adTaskManager.parseEntityFromString(entityString, adTask); + String parentTaskId = adTask.getTaskType().equals(ADTaskType.HISTORICAL_HC_ENTITY.name()) + ? adTask.getParentTaskId() // For HISTORICAL_HC_ENTITY task, return its parent task id + : adTask.getTaskId(); // For HISTORICAL_HC_DETECTOR task, its task id is parent task id + adTaskManager + .getAndExecuteOnLatestADTask( + detectorId, + parentTaskId, + entity, + ImmutableList.of(ADTaskType.HISTORICAL_HC_ENTITY), + existingEntityTask -> { + if (existingEntityTask.isPresent()) { // retry failed entity caused by limit exceed exception + // TODO: if task failed due to limit exceed exception in half way, resume from the break point or just clear + // the + // old AD tasks and rerun it? Currently we just support rerunning task failed due to limit exceed exception + // before starting. 
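+ // The existing entity task already carries a task id, so this retry path skips
+ // createADTaskDirectly and forwards the cached task straight to a worker node.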
+ ADTask adEntityTask = existingEntityTask.get(); + logger + .debug( + "Rerun entity task for task id: {}, error of last run: {}", + adEntityTask.getTaskId(), + adEntityTask.getError() + ); + ActionListener workerNodeResponseListener = workerNodeResponseListener( + adEntityTask, + transportService, + listener + ); + forwardOrExecuteEntityTask(adEntityTask, transportService, workerNodeResponseListener); + } else { + logger.info("Create entity task for entity:{}", entityString); + Instant now = Instant.now(); + ADTask adEntityTask = new ADTask.Builder() + .detectorId(adTask.getId()) + .detector(detector) + .isLatest(true) + .taskType(ADTaskType.HISTORICAL_HC_ENTITY.name()) + .executionStartTime(now) + .taskProgress(0.0f) + .initProgress(0.0f) + .state(ADTaskState.INIT.name()) + .initProgress(0.0f) + .lastUpdateTime(now) + .startedBy(adTask.getStartedBy()) + .coordinatingNode(clusterService.localNode().getId()) + .detectionDateRange(adTask.getDetectionDateRange()) + .user(adTask.getUser()) + .entity(entity) + .parentTaskId(parentTaskId) + .build(); + adTaskManager.createADTaskDirectly(adEntityTask, r -> { + adEntityTask.setTaskId(r.getId()); + ActionListener workerNodeResponseListener = workerNodeResponseListener( + adEntityTask, + transportService, + listener + ); + forwardOrExecuteEntityTask(adEntityTask, transportService, workerNodeResponseListener); + }, wrappedListener); + } + }, + transportService, + false, + wrappedListener + ); + } else { + Map updatedFields = new HashMap<>(); + updatedFields.put(STATE_FIELD, ADTaskState.INIT.name()); + updatedFields.put(INIT_PROGRESS_FIELD, 0.0f); + ActionListener workerNodeResponseListener = workerNodeResponseListener( + adTask, + transportService, + listener + ); + adTaskManager + .updateADTask( + adTask.getTaskId(), + updatedFields, + ActionListener + .wrap( + r -> forwardOrExecuteEntityTask(adTask, transportService, workerNodeResponseListener), + e -> { workerNodeResponseListener.onFailure(e); } + ) + ); + } + } catch (Exception e) { + logger.error("Failed to forward or execute AD task " + adTask.getTaskId(), e); + listener.onFailure(e); + } + } + + /** + * Return delegated listener to listen to task execution response. After task + * dispatched to worker node, this listener will listen to response from + * worker node. + * + * @param adTask AD task + * @param transportService transport service + * @param listener action listener + * @return action listener + */ + private ActionListener workerNodeResponseListener( + ADTask adTask, + TransportService transportService, + ActionListener listener + ) { + ActionListener actionListener = ActionListener.wrap(r -> { + listener.onResponse(r); + if (adTask.isEntityTask()) { + // When reach this line, the entity task already been put into worker node's cache. + // Then it's safe to move entity from temp entities queue to running entities queue. + adTaskCacheManager.moveToRunningEntity(adTask.getId(), adTaskManager.convertEntityToString(adTask)); + } + startNewEntityTaskLane(adTask, transportService); + }, e -> { + logger.error("Failed to dispatch task to worker node, task id: " + adTask.getTaskId(), e); + listener.onFailure(e); + handleException(adTask, e); + + if (adTask.getDetector().isHighCardinality()) { + // Entity task done on worker node. Send entity task done message to coordinating node to poll next entity. 
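+ // If task lanes are still available, a new lane is scheduled below after a short delay so one
+ // failed dispatch does not stall the remaining entities of the HC run.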
+ adTaskManager.entityTaskDone(adTask, e, transportService); + if (adTaskCacheManager.getAvailableNewEntityTaskLanes(adTask.getId()) > 0) { + // When reach this line, it means entity task failed to start on worker node + // Sleep some time before starting new task lane. + threadPool + .schedule( + () -> startNewEntityTaskLane(adTask, transportService), + TimeValue.timeValueSeconds(SLEEP_TIME_FOR_NEXT_ENTITY_TASK_IN_MILLIS), + AD_BATCH_TASK_THREAD_POOL_NAME + ); + } + } + }); + + ThreadedActionListener threadedActionListener = new ThreadedActionListener<>( + logger, + threadPool, + AD_BATCH_TASK_THREAD_POOL_NAME, + actionListener, + false + ); + return threadedActionListener; + } + + private void forwardOrExecuteEntityTask( + ADTask adTask, + TransportService transportService, + ActionListener workerNodeResponseListener + ) { + checkIfADTaskCancelledAndCleanupCache(adTask); + dispatchTask(adTask, ActionListener.wrap(node -> { + if (clusterService.localNode().getId().equals(node.getId())) { + // Execute batch task locally + startADBatchTaskOnWorkerNode(adTask, false, transportService, workerNodeResponseListener); + } else { + // Execute batch task remotely + transportService + .sendRequest( + node, + ADBatchTaskRemoteExecutionAction.NAME, + new ADBatchAnomalyResultRequest(adTask), + option, + new ActionListenerResponseHandler<>(workerNodeResponseListener, ADBatchAnomalyResultResponse::new) + ); + } + }, e -> workerNodeResponseListener.onFailure(e))); + } + + // start new entity task lane + private synchronized void startNewEntityTaskLane(ADTask adTask, TransportService transportService) { + if (adTask.getDetector().isHighCardinality() && adTaskCacheManager.getAndDecreaseEntityTaskLanes(adTask.getId()) > 0) { + logger.debug("start new task lane for detector {}", adTask.getId()); + forwardOrExecuteADTask(adTask, transportService, getInternalHCDelegatedListener(adTask)); + } + } + + private void dispatchTask(ADTask adTask, ActionListener listener) { + hashRing.getNodesWithSameLocalAdVersion(dataNodes -> { + ADStatsRequest adStatsRequest = new ADStatsRequest(dataNodes); + adStatsRequest.addAll(ImmutableSet.of(AD_EXECUTING_BATCH_TASK_COUNT.getName(), JVM_HEAP_USAGE.getName())); + + client.execute(ADStatsNodesAction.INSTANCE, adStatsRequest, ActionListener.wrap(adStatsResponse -> { + List candidateNodeResponse = adStatsResponse + .getNodes() + .stream() + .filter(stat -> (long) stat.getStatsMap().get(JVM_HEAP_USAGE.getName()) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD) + .collect(Collectors.toList()); + + if (candidateNodeResponse.size() == 0) { + StringBuilder errorMessageBuilder = new StringBuilder("All nodes' memory usage exceeds limitation ") + .append(DEFAULT_JVM_HEAP_USAGE_THRESHOLD) + .append("%. 
") + .append(NO_ELIGIBLE_NODE_TO_RUN_DETECTOR) + .append(adTask.getId()); + String errorMessage = errorMessageBuilder.toString(); + logger.warn(errorMessage + ", task id " + adTask.getTaskId() + ", " + adTask.getTaskType()); + listener.onFailure(new LimitExceededException(adTask.getId(), errorMessage)); + return; + } + candidateNodeResponse = candidateNodeResponse + .stream() + .filter(stat -> (Long) stat.getStatsMap().get(AD_EXECUTING_BATCH_TASK_COUNT.getName()) < maxAdBatchTaskPerNode) + .collect(Collectors.toList()); + if (candidateNodeResponse.size() == 0) { + StringBuilder errorMessageBuilder = new StringBuilder("All nodes' executing batch tasks exceeds limitation ") + .append(NO_ELIGIBLE_NODE_TO_RUN_DETECTOR) + .append(adTask.getId()); + String errorMessage = errorMessageBuilder.toString(); + logger.warn(errorMessage + ", task id " + adTask.getTaskId() + ", " + adTask.getTaskType()); + listener.onFailure(new LimitExceededException(adTask.getId(), errorMessage)); + return; + } + Optional targetNode = candidateNodeResponse + .stream() + .sorted((ADStatsNodeResponse r1, ADStatsNodeResponse r2) -> { + int result = ((Long) r1.getStatsMap().get(AD_EXECUTING_BATCH_TASK_COUNT.getName())) + .compareTo((Long) r2.getStatsMap().get(AD_EXECUTING_BATCH_TASK_COUNT.getName())); + if (result == 0) { + // if multiple nodes have same running task count, choose the one with least + // JVM heap usage. + return ((Long) r1.getStatsMap().get(JVM_HEAP_USAGE.getName())) + .compareTo((Long) r2.getStatsMap().get(JVM_HEAP_USAGE.getName())); + } + return result; + }) + .findFirst(); + listener.onResponse(targetNode.get().getNode()); + }, exception -> { + logger.error("Failed to get node's task stats", exception); + listener.onFailure(exception); + })); + }, listener); + } + + /** + * Start AD task in dedicated batch task thread pool on worker node. + * + * @param adTask ad task + * @param runTaskRemotely run task remotely or not + * @param transportService transport service + * @param delegatedListener action listener + */ + public void startADBatchTaskOnWorkerNode( + ADTask adTask, + boolean runTaskRemotely, + TransportService transportService, + ActionListener delegatedListener + ) { + try { + // check if cluster is eligible to run AD currently, if not eligible like + // circuit breaker open, will throw exception. + checkClusterState(adTask); + threadPool.executor(AD_BATCH_TASK_THREAD_POOL_NAME).execute(() -> { + ActionListener internalListener = internalBatchTaskListener(adTask, transportService); + try { + executeADBatchTaskOnWorkerNode(adTask, internalListener); + } catch (Exception e) { + internalListener.onFailure(e); + } + }); + delegatedListener.onResponse(new ADBatchAnomalyResultResponse(clusterService.localNode().getId(), runTaskRemotely)); + } catch (Exception e) { + logger.error("Fail to start AD batch task " + adTask.getTaskId(), e); + delegatedListener.onFailure(e); + } + } + + private ActionListener internalBatchTaskListener(ADTask adTask, TransportService transportService) { + String taskId = adTask.getTaskId(); + String detectorTaskId = adTask.getDetectorLevelTaskId(); + String detectorId = adTask.getId(); + ActionListener listener = ActionListener.wrap(response -> { + // If batch task finished normally, remove task from cache and decrease executing task count by 1. 
+ adTaskCacheManager.remove(taskId, detectorId, detectorTaskId); + adStats.getStat(AD_EXECUTING_BATCH_TASK_COUNT.getName()).decrement(); + if (!adTask.getDetector().isHighCardinality()) { + // Set single-entity detector task as FINISHED here + adTaskManager + .cleanDetectorCache( + adTask, + transportService, + () -> adTaskManager.updateADTask(taskId, ImmutableMap.of(STATE_FIELD, ADTaskState.FINISHED.name())) + ); + } else { + // Set entity task as FINISHED here + adTaskManager.updateADTask(adTask.getTaskId(), ImmutableMap.of(STATE_FIELD, ADTaskState.FINISHED.name())); + adTaskManager.entityTaskDone(adTask, null, transportService); + } + }, e -> { + // If batch task failed, remove task from cache and decrease executing task count by 1. + adTaskCacheManager.remove(taskId, detectorId, detectorTaskId); + adStats.getStat(AD_EXECUTING_BATCH_TASK_COUNT.getName()).decrement(); + if (!adTask.getDetector().isHighCardinality()) { + adTaskManager.cleanDetectorCache(adTask, transportService, () -> handleException(adTask, e)); + } else { + adTaskManager.entityTaskDone(adTask, e, transportService); + handleException(adTask, e); + } + }); + ThreadedActionListener threadedActionListener = new ThreadedActionListener<>( + logger, + threadPool, + AD_BATCH_TASK_THREAD_POOL_NAME, + listener, + false + ); + return threadedActionListener; + } + + private void handleException(ADTask adTask, Exception e) { + // Check if batch task was cancelled or not by exception type. + // If it's cancelled, then increase cancelled task count by 1, otherwise increase failure count by 1. + if (e instanceof TaskCancelledException) { + adStats.getStat(StatNames.AD_CANCELED_BATCH_TASK_COUNT.getName()).increment(); + } else if (ExceptionUtil.countInStats(e)) { + adStats.getStat(StatNames.AD_BATCH_TASK_FAILURE_COUNT.getName()).increment(); + } + // Handle AD task exception + adTaskManager.handleADTaskException(adTask, e); + } + + private void executeADBatchTaskOnWorkerNode(ADTask adTask, ActionListener internalListener) { + // track AD executing batch task and total batch task execution count + adStats.getStat(AD_EXECUTING_BATCH_TASK_COUNT.getName()).increment(); + adStats.getStat(StatNames.AD_TOTAL_BATCH_TASK_EXECUTION_COUNT.getName()).increment(); + + // put AD task into cache + adTaskCacheManager.add(adTask); + + // start to run first piece + Instant executeStartTime = Instant.now(); + // TODO: refactor to make the workflow more clear + runFirstPiece(adTask, executeStartTime, internalListener); + } + + private void checkClusterState(ADTask adTask) { + // check if AD plugin is enabled + checkADPluginEnabled(adTask.getId()); + + // check if circuit breaker is open + checkCircuitBreaker(adTask); + } + + private void checkADPluginEnabled(String detectorId) { + if (!ADEnabledSetting.isADEnabled()) { + throw new EndRunException(detectorId, ADCommonMessages.DISABLED_ERR_MSG, true).countedInStats(false); + } + } + + private void checkCircuitBreaker(ADTask adTask) { + String taskId = adTask.getTaskId(); + if (adCircuitBreakerService.isOpen()) { + String error = "Circuit breaker is open"; + logger.error("AD task: {}, {}", taskId, error); + throw new LimitExceededException(adTask.getId(), error, true); + } + } + + private void runFirstPiece(ADTask adTask, Instant executeStartTime, ActionListener internalListener) { + try { + adTaskManager + .updateADTask( + adTask.getTaskId(), + ImmutableMap + .of( + STATE_FIELD, + ADTaskState.INIT.name(), + CURRENT_PIECE_FIELD, + adTask.getDetectionDateRange().getStartTime().toEpochMilli(), + 
TASK_PROGRESS_FIELD, + 0.0f, + INIT_PROGRESS_FIELD, + 0.0f, + WORKER_NODE_FIELD, + clusterService.localNode().getId() + ), + ActionListener.wrap(r -> { + try { + checkIfADTaskCancelledAndCleanupCache(adTask); + getDateRangeOfSourceData(adTask, (dataStartTime, dataEndTime) -> { + long interval = ((IntervalTimeConfiguration) adTask.getDetector().getInterval()).toDuration().toMillis(); + long expectedPieceEndTime = dataStartTime + pieceSize * interval; + long firstPieceEndTime = Math.min(expectedPieceEndTime, dataEndTime); + logger + .debug( + "start first piece from {} to {}, interval {}, dataStartTime {}, dataEndTime {}," + + " detectorId {}, taskId {}", + dataStartTime, + firstPieceEndTime, + interval, + dataStartTime, + dataEndTime, + adTask.getId(), + adTask.getTaskId() + ); + getFeatureData( + adTask, + dataStartTime, // first piece start time + firstPieceEndTime, // first piece end time + dataStartTime, + dataEndTime, + interval, + executeStartTime, + internalListener + ); + }, internalListener); + } catch (Exception e) { + internalListener.onFailure(e); + } + }, internalListener::onFailure) + ); + } catch (Exception exception) { + internalListener.onFailure(exception); + } + } + + private void getDateRangeOfSourceData(ADTask adTask, BiConsumer consumer, ActionListener internalListener) { + String taskId = adTask.getTaskId(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .aggregation(AggregationBuilders.min(CommonName.AGG_NAME_MIN_TIME).field(adTask.getDetector().getTimeField())) + .aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(adTask.getDetector().getTimeField())) + .size(0); + if (adTask.getEntity() != null && adTask.getEntity().getAttributes().size() > 0) { + BoolQueryBuilder query = new BoolQueryBuilder(); + adTask + .getEntity() + .getAttributes() + .entrySet() + .forEach(entity -> query.filter(new TermQueryBuilder(entity.getKey(), entity.getValue()))); + searchSourceBuilder.query(query); + } + + SearchRequest request = new SearchRequest() + .indices(adTask.getDetector().getIndices().toArray(new String[0])) + .source(searchSourceBuilder); + final ActionListener searchResponseListener = ActionListener.wrap(r -> { + InternalMin minAgg = r.getAggregations().get(CommonName.AGG_NAME_MIN_TIME); + InternalMax maxAgg = r.getAggregations().get(CommonName.AGG_NAME_MAX_TIME); + double minValue = minAgg.getValue(); + double maxValue = maxAgg.getValue(); + // If time field not exist or there is no value, will return infinity value + if (minValue == Double.POSITIVE_INFINITY) { + internalListener.onFailure(new ResourceNotFoundException(adTask.getId(), "There is no data in the time field")); + return; + } + long interval = ((IntervalTimeConfiguration) adTask.getDetector().getInterval()).toDuration().toMillis(); + + DateRange detectionDateRange = adTask.getDetectionDateRange(); + long dataStartTime = detectionDateRange.getStartTime().toEpochMilli(); + long dataEndTime = detectionDateRange.getEndTime().toEpochMilli(); + long minDate = (long) minValue; + long maxDate = (long) maxValue; + + if (minDate >= dataEndTime || maxDate <= dataStartTime) { + internalListener.onFailure(new ResourceNotFoundException(adTask.getId(), "There is no data in the detection date range")); + return; + } + if (minDate > dataStartTime) { + dataStartTime = minDate; + } + if (maxDate < dataEndTime) { + dataEndTime = maxDate; + } + + // normalize start/end time to make it consistent with feature data agg result + dataStartTime = dataStartTime - dataStartTime % interval; + 
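// Illustrative example (not in the original source): with a 10-minute interval (600_000 ms),
+ // a raw start of 1_600_000_250_000 floors to 1_600_000_200_000; the end time on the next line
+ // is floored the same way, so piece boundaries line up with the feature aggregation buckets. +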
dataEndTime = dataEndTime - dataEndTime % interval; + logger.debug("adjusted date range: start: {}, end: {}, taskId: {}", dataStartTime, dataEndTime, taskId); + if ((dataEndTime - dataStartTime) < NUM_MIN_SAMPLES * interval) { + internalListener.onFailure(new TimeSeriesException("There is not enough data to train model").countedInStats(false)); + return; + } + consumer.accept(dataStartTime, dataEndTime); + }, e -> { internalListener.onFailure(e); }); + + // inject user role while searching. + clientUtil + .asyncRequestWithInjectedSecurity( + request, + client::search, + // user is the one who started historical detector. Read AnomalyDetectorJobTransportAction.doExecute. + adTask.getUser(), + client, + searchResponseListener + ); + } + + private void getFeatureData( + ADTask adTask, + long pieceStartTime, + long pieceEndTime, + long dataStartTime, + long dataEndTime, + long interval, + Instant executeStartTime, + ActionListener internalListener + ) { + ActionListener>> actionListener = ActionListener.wrap(dataPoints -> { + try { + if (dataPoints.size() == 0) { + logger.debug("No data in current piece with end time: " + pieceEndTime); + // Current piece end time is the next piece's start time + runNextPiece(adTask, pieceEndTime, dataStartTime, dataEndTime, interval, internalListener); + } else { + detectAnomaly( + adTask, + dataPoints, + pieceStartTime, + pieceEndTime, + dataStartTime, + dataEndTime, + interval, + executeStartTime, + internalListener + ); + } + } catch (Exception e) { + internalListener.onFailure(e); + } + }, exception -> { + logger.debug("Fail to get feature data by batch for this piece with end time: " + pieceEndTime); + // TODO: Exception may be caused by wrong feature query or some bad data. Differentiate these + // and skip current piece if error caused by bad data. + internalListener.onFailure(exception); + }); + ThreadedActionListener>> threadedActionListener = new ThreadedActionListener<>( + logger, + threadPool, + AD_BATCH_TASK_THREAD_POOL_NAME, + actionListener, + false + ); + + featureManager + .getFeatureDataPointsByBatch(adTask.getDetector(), adTask.getEntity(), pieceStartTime, pieceEndTime, threadedActionListener); + } + + private void detectAnomaly( + ADTask adTask, + Map> dataPoints, + long pieceStartTime, + long pieceEndTime, + long dataStartTime, + long dataEndTime, + long interval, + Instant executeStartTime, + ActionListener internalListener + ) { + String taskId = adTask.getTaskId(); + ThresholdedRandomCutForest trcf = adTaskCacheManager.getTRcfModel(taskId); + Deque>> shingle = adTaskCacheManager.getShingle(taskId); + + List anomalyResults = new ArrayList<>(); + + long intervalEndTime = pieceStartTime; + for (int i = 0; i < pieceSize && intervalEndTime < dataEndTime; i++) { + Optional dataPoint = dataPoints.containsKey(intervalEndTime) ? dataPoints.get(intervalEndTime) : Optional.empty(); + intervalEndTime = intervalEndTime + interval; + SinglePointFeatures feature = featureManager + .getShingledFeatureForHistoricalAnalysis(adTask.getDetector(), shingle, dataPoint, intervalEndTime); + List featureData = null; + if (feature.getUnprocessedFeatures().isPresent()) { + featureData = ParseUtils.getFeatureData(feature.getUnprocessedFeatures().get(), adTask.getDetector()); + } + if (!feature.getProcessedFeatures().isPresent()) { + String error = feature.getUnprocessedFeatures().isPresent() + ? 
"No full shingle in current detection window" + : "No data in current detection window"; + AnomalyResult anomalyResult = new AnomalyResult( + adTask.getId(), + adTask.getDetectorLevelTaskId(), + featureData, + Instant.ofEpochMilli(intervalEndTime - interval), + Instant.ofEpochMilli(intervalEndTime), + executeStartTime, + Instant.now(), + error, + Optional.ofNullable(adTask.getEntity()), + adTask.getDetector().getUser(), + anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT), + adTask.getEntityModelId() + ); + anomalyResults.add(anomalyResult); + } else { + double[] point = feature.getProcessedFeatures().get(); + // 0 is placeholder for timestamp. In the future, we will add + // data time stamp there. + AnomalyDescriptor descriptor = trcf.process(point, 0); + double score = descriptor.getRCFScore(); + if (!adTaskCacheManager.isThresholdModelTrained(taskId) && score > 0) { + adTaskCacheManager.setThresholdModelTrained(taskId, true); + } + + AnomalyResult anomalyResult = AnomalyResult + .fromRawTRCFResult( + adTask.getId(), + adTask.getDetector().getIntervalInMilliseconds(), + adTask.getDetectorLevelTaskId(), + score, + descriptor.getAnomalyGrade(), + descriptor.getDataConfidence(), + featureData, + Instant.ofEpochMilli(intervalEndTime - interval), + Instant.ofEpochMilli(intervalEndTime), + executeStartTime, + Instant.now(), + null, + Optional.ofNullable(adTask.getEntity()), + adTask.getDetector().getUser(), + anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT), + adTask.getEntityModelId(), + modelManager.normalizeAttribution(trcf.getForest(), descriptor.getRelevantAttribution()), + descriptor.getRelativeIndex(), + descriptor.getPastValues(), + descriptor.getExpectedValuesList(), + descriptor.getLikelihoodOfValues(), + descriptor.getThreshold() + ); + anomalyResults.add(anomalyResult); + } + } + + String user; + List roles; + if (adTask.getUser() == null) { + // It's possible that user create domain with security disabled, then enable security + // after upgrading. This is for BWC, for old detectors which created when security + // disabled, the user will be null. + user = ""; + roles = settings.getAsList("", ImmutableList.of("all_access", "AmazonES_all_access")); + } else { + user = adTask.getUser().getName(); + roles = adTask.getUser().getRoles(); + } + String resultIndex = adTask.getDetector().getCustomResultIndex(); + + if (resultIndex == null) { + // if result index is null, store anomaly result directly + storeAnomalyResultAndRunNextPiece( + adTask, + pieceEndTime, + dataStartTime, + dataEndTime, + interval, + internalListener, + anomalyResults, + resultIndex, + null + ); + return; + } + + try (InjectSecurity injectSecurity = new InjectSecurity(adTask.getTaskId(), settings, client.threadPool().getThreadContext())) { + // Injecting user role to verify if the user has permissions to write result to result index. 
+ injectSecurity.inject(user, roles); + storeAnomalyResultAndRunNextPiece( + adTask, + pieceEndTime, + dataStartTime, + dataEndTime, + interval, + internalListener, + anomalyResults, + resultIndex, + () -> injectSecurity.close() + ); + } catch (Exception exception) { + logger.error("Failed to inject user roles", exception); + internalListener.onFailure(exception); + } + } + + private void storeAnomalyResultAndRunNextPiece( + ADTask adTask, + long pieceEndTime, + long dataStartTime, + long dataEndTime, + long interval, + ActionListener internalListener, + List anomalyResults, + String resultIndex, + CheckedRunnable runBefore + ) { + ActionListener actionListener = new ThreadedActionListener<>( + logger, + threadPool, + AD_BATCH_TASK_THREAD_POOL_NAME, + ActionListener.wrap(r -> { + try { + runNextPiece(adTask, pieceEndTime, dataStartTime, dataEndTime, interval, internalListener); + } catch (Exception e) { + internalListener.onFailure(e); + } + }, e -> { + logger.error("Fail to bulk index anomaly result", e); + internalListener.onFailure(e); + }), + false + ); + + anomalyResultBulkIndexHandler + .bulkIndexAnomalyResult( + resultIndex, + anomalyResults, + runBefore == null ? actionListener : ActionListener.runBefore(actionListener, runBefore) + ); + } + + private void runNextPiece( + ADTask adTask, + long pieceStartTime, + long dataStartTime, + long dataEndTime, + long interval, + ActionListener internalListener + ) { + String taskId = adTask.getTaskId(); + String detectorId = adTask.getId(); + String detectorTaskId = adTask.getDetectorLevelTaskId(); + float initProgress = calculateInitProgress(taskId); + String taskState = initProgress >= 1.0f ? ADTaskState.RUNNING.name() : ADTaskState.INIT.name(); + logger.debug("Init progress: {}, taskState:{}, task id: {}", initProgress, taskState, taskId); + + if (initProgress >= 1.0f && adTask.isEntityTask()) { + updateDetectorLevelTaskState(detectorId, adTask.getParentTaskId(), ADTaskState.RUNNING.name()); + } + + if (pieceStartTime < dataEndTime) { + checkIfADTaskCancelledAndCleanupCache(adTask); + threadPool.schedule(() -> { + checkClusterState(adTask); + long expectedPieceEndTime = pieceStartTime + pieceSize * interval; + long pieceEndTime = expectedPieceEndTime > dataEndTime ? 
dataEndTime : expectedPieceEndTime; + logger + .debug( + "task id: {}, start next piece start from {} to {}, interval {}", + adTask.getTaskId(), + pieceStartTime, + pieceEndTime, + interval + ); + float taskProgress = (float) (pieceStartTime - dataStartTime) / (dataEndTime - dataStartTime); + logger.debug("Task progress: {}, task id:{}, detector id:{}", taskProgress, taskId, detectorId); + adTaskManager + .updateADTask( + taskId, + ImmutableMap + .of( + STATE_FIELD, + taskState, + CURRENT_PIECE_FIELD, + pieceStartTime, + TASK_PROGRESS_FIELD, + taskProgress, + INIT_PROGRESS_FIELD, + initProgress + ), + ActionListener + .wrap( + r -> getFeatureData( + adTask, + pieceStartTime, + pieceEndTime, + dataStartTime, + dataEndTime, + interval, + Instant.now(), + internalListener + ), + e -> internalListener.onFailure(e) + ) + ); + }, TimeValue.timeValueSeconds(pieceIntervalSeconds), AD_BATCH_TASK_THREAD_POOL_NAME); + } else { + logger.info("AD task finished for detector {}, task id: {}", detectorId, taskId); + adTaskCacheManager.remove(taskId, detectorId, detectorTaskId); + adTaskManager + .updateADTask( + taskId, + ImmutableMap + .of( + CURRENT_PIECE_FIELD, + dataEndTime, + TASK_PROGRESS_FIELD, + 1.0f, + EXECUTION_END_TIME_FIELD, + Instant.now().toEpochMilli(), + INIT_PROGRESS_FIELD, + initProgress, + STATE_FIELD, + ADTaskState.FINISHED + ), + ActionListener.wrap(r -> internalListener.onResponse("task execution done"), e -> internalListener.onFailure(e)) + ); + } + } + + private void updateDetectorLevelTaskState(String detectorId, String detectorTaskId, String newState) { + ExecutorFunction function = () -> adTaskManager + .updateADTask(detectorTaskId, ImmutableMap.of(STATE_FIELD, newState), ActionListener.wrap(r -> { + logger.info("Updated HC detector task: {} state as: {} for detector: {}", detectorTaskId, newState, detectorId); + adTaskCacheManager.updateDetectorTaskState(detectorId, detectorTaskId, newState); + }, e -> { logger.error("Failed to update HC detector task: {} for detector: {}", detectorTaskId, detectorId); })); + if (adTaskCacheManager.detectorTaskStateExists(detectorId, detectorTaskId)) { + if (!Objects.equals(adTaskCacheManager.getDetectorTaskState(detectorId, detectorTaskId), newState)) { + function.execute(); + } + } else if (!adTaskCacheManager.isHistoricalAnalysisCancelledForHC(detectorId, detectorTaskId)) { + adTaskManager.getADTask(detectorTaskId, ActionListener.wrap(task -> { + if (task.isPresent()) { + if (!Objects.equals(task.get().getState(), newState)) { + function.execute(); + } + } + }, exception -> { logger.error("failed to get detector level task " + detectorTaskId, exception); })); + } + } + + private float calculateInitProgress(String taskId) { + RandomCutForest rcf = adTaskCacheManager.getTRcfModel(taskId).getForest(); + if (rcf == null) { + return 0.0f; + } + float initProgress = (float) rcf.getTotalUpdates() / NUM_MIN_SAMPLES; + logger.debug("RCF total updates {} for task {}", rcf.getTotalUpdates(), taskId); + return initProgress > 1.0f ? 
1.0f : initProgress; + } + + private void checkIfADTaskCancelledAndCleanupCache(ADTask adTask) { + String taskId = adTask.getTaskId(); + String detectorId = adTask.getId(); + String detectorTaskId = adTask.getDetectorLevelTaskId(); + // refresh latest HC task run time + adTaskCacheManager.refreshLatestHCTaskRunTime(detectorId); + if (adTask.getDetector().isHighCardinality() + && adTaskCacheManager.isHCTaskCoordinatingNode(detectorId) + && adTaskCacheManager.isHistoricalAnalysisCancelledForHC(detectorId, detectorTaskId)) { + // clean up pending and running entity on coordinating node + adTaskCacheManager.clearPendingEntities(detectorId); + adTaskCacheManager.removeRunningEntity(detectorId, adTaskManager.convertEntityToString(adTask)); + throw new TaskCancelledException( + adTaskCacheManager.getCancelReasonForHC(detectorId, detectorTaskId), + adTaskCacheManager.getCancelledByForHC(detectorId, detectorTaskId) + ); + } + + if (adTaskCacheManager.contains(taskId) && adTaskCacheManager.isCancelled(taskId)) { + logger.info("AD task cancelled, stop running task {}", taskId); + String cancelReason = adTaskCacheManager.getCancelReason(taskId); + String cancelledBy = adTaskCacheManager.getCancelledBy(taskId); + adTaskCacheManager.remove(taskId, detectorId, detectorTaskId); + if (!adTaskCacheManager.isHCTaskCoordinatingNode(detectorId) + && isNullOrEmpty(adTaskCacheManager.getTasksOfDetector(detectorId))) { + // Clean up historical task cache for HC detector on worker node if no running entity task. + logger.info("All AD task cancelled, cleanup historical task cache for detector {}", detectorId); + adTaskCacheManager.removeHistoricalTaskCache(detectorId); + } + + throw new TaskCancelledException(cancelReason, cancelledBy); + } + } + +} diff --git a/src/main/java/org/opensearch/ad/task/ADHCBatchTaskCache.java-e b/src/main/java/org/opensearch/ad/task/ADHCBatchTaskCache.java-e new file mode 100644 index 000000000..3ad2a91c3 --- /dev/null +++ b/src/main/java/org/opensearch/ad/task/ADHCBatchTaskCache.java-e @@ -0,0 +1,286 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.task; + +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * AD HC detector batch task cache which will mainly hold these for HC detector on + * coordinating node. + *
<ul>
+ *   <li>pending entities queue</li>
+ *   <li>running entities queue</li>
+ *   <li>temp entities queue</li>
+ *   <li>entity task lanes</li>
+ *   <li>top entities count</li>
+ *   <li>top entities inited or not</li>
+ *   <li>task retry times</li>
+ *   <li>detector task update semaphore to control only 1 thread update detector level task</li>
+ * </ul>
+ */ +public class ADHCBatchTaskCache { + + // Cache pending entities. + private Queue pendingEntities; + + // Cache running entities. + private Queue runningEntities; + + // Will move entity from pending queue to this temp queue once task dispatched to work node. + // If fail to dispatch to work node, will move entity from temp queue to pending queue. + // If work node returns response successfully, will move entity from temp queue to running queue. + // If we just move entity from pending queue to running queue directly, the running queue can't + // match the real running task on worker nodes. + private Queue tempEntities; + + // How many entity task lanes can run concurrently. One entity task lane can run one entity task. + // Entity lane is a virtual concept which represents one running entity task. + private AtomicInteger entityTaskLanes; + + // How many top entities totally for this HC task. + // Will calculate HC task progress with it and profile API needs this. + private Integer topEntityCount; + + // This is to control only one entity task updating detector level task. + private Semaphore detectorTaskUpdatingSemaphore; + + // Top entities inited or not. + private Boolean topEntitiesInited; + + // Record how many times the task has retried. Key is task id. + private Map taskRetryTimes; + + // record last time when HC detector scales entity task slots + private Instant lastScaleEntityTaskSlotsTime; + + // record lastest HC detector task run time, will use this field to check if task is running or not. + private Instant latestTaskRunTime; + + public ADHCBatchTaskCache() { + this.pendingEntities = new ConcurrentLinkedQueue<>(); + this.runningEntities = new ConcurrentLinkedQueue<>(); + this.tempEntities = new ConcurrentLinkedQueue<>(); + this.taskRetryTimes = new ConcurrentHashMap<>(); + this.detectorTaskUpdatingSemaphore = new Semaphore(1); + this.topEntitiesInited = false; + this.lastScaleEntityTaskSlotsTime = Instant.now(); + this.latestTaskRunTime = Instant.now(); + } + + public void setTopEntityCount(Integer topEntityCount) { + this.refreshLatestTaskRunTime(); + this.topEntityCount = topEntityCount; + } + + public String[] getPendingEntities() { + return pendingEntities.toArray(new String[0]); + } + + public String[] getRunningEntities() { + return runningEntities.toArray(new String[0]); + } + + public String[] getTempEntities() { + return tempEntities.toArray(new String[0]); + } + + public Integer getTopEntityCount() { + return topEntityCount; + } + + public boolean tryAcquireTaskUpdatingSemaphore(long timeoutInMillis) throws InterruptedException { + return detectorTaskUpdatingSemaphore.tryAcquire(timeoutInMillis, TimeUnit.MILLISECONDS); + } + + public void releaseTaskUpdatingSemaphore() { + detectorTaskUpdatingSemaphore.release(); + } + + public boolean getTopEntitiesInited() { + return topEntitiesInited; + } + + public void setEntityTaskLanes(int entityTaskLanes) { + this.refreshLatestTaskRunTime(); + this.entityTaskLanes = new AtomicInteger(entityTaskLanes); + } + + public int getAndDecreaseEntityTaskLanes() { + return this.entityTaskLanes.getAndDecrement(); + } + + public int getEntityTaskLanes() { + return this.entityTaskLanes.get(); + } + + public void setTopEntitiesInited(boolean inited) { + this.topEntitiesInited = inited; + } + + public int getTaskRetryTimes(String taskId) { + return taskRetryTimes.computeIfAbsent(taskId, id -> new AtomicInteger(0)).get(); + } + + /** + * Remove entities from both temp and running entities queue and add list of entities into pending 
entity queue. + * @param entities a list of entity + */ + public void addPendingEntities(List entities) { + this.refreshLatestTaskRunTime(); + if (entities == null || entities.size() == 0) { + return; + } + for (String entity : entities) { + if (entity != null) { + // make sure we delete from temp and running queue first before adding the entity to pending queue + tempEntities.remove(entity); + runningEntities.remove(entity); + if (!pendingEntities.contains(entity)) { + pendingEntities.add(entity); + } + } + } + } + + /** + * Move entity to running entity queue. + * @param entity entity value + */ + public void moveToRunningEntity(String entity) { + this.refreshLatestTaskRunTime(); + if (entity == null) { + return; + } + boolean removed = this.tempEntities.remove(entity); + // It's possible that entity was removed before this function. Should check if + // task in temp queue or not before adding it to running queue. + if (removed && !this.runningEntities.contains(entity)) { + this.runningEntities.add(entity); + // clean it from pending queue to make sure entity only exists in running queue + this.pendingEntities.remove(entity); + } + } + + public int getPendingEntityCount() { + return this.pendingEntities.size(); + } + + public int getRunningEntityCount() { + return this.runningEntities.size(); + } + + public int getUnfinishedEntityCount() { + return this.runningEntities.size() + this.tempEntities.size() + this.pendingEntities.size(); + } + + public int getTempEntityCount() { + return this.tempEntities.size(); + } + + public Instant getLastScaleEntityTaskSlotsTime() { + return this.lastScaleEntityTaskSlotsTime; + } + + public void setLastScaleEntityTaskSlotsTime(Instant lastScaleEntityTaskSlotsTime) { + this.lastScaleEntityTaskSlotsTime = lastScaleEntityTaskSlotsTime; + } + + public Instant getLatestTaskRunTime() { + return latestTaskRunTime; + } + + public void refreshLatestTaskRunTime() { + this.latestTaskRunTime = Instant.now(); + } + + public boolean hasEntity() { + return !this.pendingEntities.isEmpty() || !this.runningEntities.isEmpty() || !this.tempEntities.isEmpty(); + } + + public boolean hasRunningEntity() { + return !this.runningEntities.isEmpty() || !this.tempEntities.isEmpty(); + } + + public boolean removeRunningEntity(String entity) { + // In normal case, the entity will be moved to running queue if entity task dispatched + // to worker node successfully. If failed to dispatch to worker node, it will still stay + // in temp queue, check ADBatchTaskRunner#workerNodeResponseListener. Then will send + // entity task done message to coordinating node to move to pending queue if exception + // is retryable or remove entity from cache if not retryable. + return this.runningEntities.remove(entity) || this.tempEntities.remove(entity); + } + + /** + * Clear pending/running/temp entities queues, task retry times and rate limiter cache. + */ + public void clear() { + this.pendingEntities.clear(); + this.runningEntities.clear(); + this.tempEntities.clear(); + this.taskRetryTimes.clear(); + } + + /** + * Poll one entity from pending entities queue. If entity exists, move it into + * temp entities queue. + * @return entity value + */ + public String pollEntity() { + this.refreshLatestTaskRunTime(); + String entity = this.pendingEntities.poll(); + if (entity != null && !this.tempEntities.contains(entity)) { + this.tempEntities.add(entity); + } + return entity; + } + + /** + * Clear pending entities queue. 
+ */ + public void clearPendingEntities() { + this.pendingEntities.clear(); + } + + /** + * Increase task retry times by 1. + * @param taskId task id + * @return current retry time + */ + public int increaseTaskRetry(String taskId) { + return this.taskRetryTimes.computeIfAbsent(taskId, id -> new AtomicInteger(0)).getAndIncrement(); + } + + /** + * Check if entity exists in temp entities queue, pending entities queue or running + * entities queue. If exists, remove from these queues. + * @param entity entity value + * @return true if entity exists and removed + */ + public boolean removeEntity(String entity) { + this.refreshLatestTaskRunTime(); + if (entity == null) { + return false; + } + boolean removedFromTempQueue = tempEntities.remove(entity); + boolean removedFromPendingQueue = pendingEntities.remove(entity); + boolean removedFromRunningQueue = runningEntities.remove(entity); + return removedFromTempQueue || removedFromPendingQueue || removedFromRunningQueue; + } +} diff --git a/src/main/java/org/opensearch/ad/task/ADHCBatchTaskRunState.java-e b/src/main/java/org/opensearch/ad/task/ADHCBatchTaskRunState.java-e new file mode 100644 index 000000000..91f00b4cd --- /dev/null +++ b/src/main/java/org/opensearch/ad/task/ADHCBatchTaskRunState.java-e @@ -0,0 +1,92 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.task; + +import java.time.Instant; + +import org.opensearch.ad.model.ADTaskState; + +/** + * Cache HC batch task running state on coordinating and worker node. + */ +public class ADHCBatchTaskRunState { + + // HC batch task run state will expire after 60 seconds after last task run time or task cancelled time. + public static final int HC_TASK_RUN_STATE_TIMEOUT_IN_MILLIS = 60_000; + private String detectorTaskState; + // record if HC detector historical analysis cancelled/stopped. Every entity task should + // recheck this field and stop if it's true. 
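+ // The batch task runner's checkIfADTaskCancelledAndCleanupCache reads this state and throws
+ // TaskCancelledException, so a running entity task stops at its next checkpoint (for example,
+ // before scheduling the next piece).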
+ private boolean isHistoricalAnalysisCancelled; + private String cancelReason; + private String cancelledBy; + private Long lastTaskRunTimeInMillis; + private Long cancelledTimeInMillis; + + public ADHCBatchTaskRunState() { + this.detectorTaskState = ADTaskState.INIT.name(); + } + + public String getDetectorTaskState() { + return detectorTaskState; + } + + public void setDetectorTaskState(String detectorTaskState) { + this.detectorTaskState = detectorTaskState; + } + + public boolean getHistoricalAnalysisCancelled() { + return isHistoricalAnalysisCancelled; + } + + public void setHistoricalAnalysisCancelled(boolean historicalAnalysisCancelled) { + isHistoricalAnalysisCancelled = historicalAnalysisCancelled; + } + + public String getCancelReason() { + return cancelReason; + } + + public void setCancelReason(String cancelReason) { + this.cancelReason = cancelReason; + } + + public String getCancelledBy() { + return cancelledBy; + } + + public void setCancelledBy(String cancelledBy) { + this.cancelledBy = cancelledBy; + } + + public void setCancelledTimeInMillis(Long cancelledTimeInMillis) { + this.cancelledTimeInMillis = cancelledTimeInMillis; + } + + public void setLastTaskRunTimeInMillis(Long lastTaskRunTimeInMillis) { + this.lastTaskRunTimeInMillis = lastTaskRunTimeInMillis; + } + + public boolean expired() { + long nowInMillis = Instant.now().toEpochMilli(); + if (isHistoricalAnalysisCancelled + && cancelledTimeInMillis != null + && cancelledTimeInMillis + HC_TASK_RUN_STATE_TIMEOUT_IN_MILLIS < nowInMillis) { + return true; + } + if (!isHistoricalAnalysisCancelled + && lastTaskRunTimeInMillis != null + && lastTaskRunTimeInMillis + HC_TASK_RUN_STATE_TIMEOUT_IN_MILLIS < nowInMillis) { + return true; + } + return false; + } +} diff --git a/src/main/java/org/opensearch/ad/task/ADRealtimeTaskCache.java-e b/src/main/java/org/opensearch/ad/task/ADRealtimeTaskCache.java-e new file mode 100644 index 000000000..bf8cbb860 --- /dev/null +++ b/src/main/java/org/opensearch/ad/task/ADRealtimeTaskCache.java-e @@ -0,0 +1,93 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.task; + +import java.time.Instant; + +/** + * AD realtime task cache which will hold these data + * 1. task state + * 2. init progress + * 3. error + * 4. last job run time + * 5. detector interval + */ +public class ADRealtimeTaskCache { + + // task state + private String state; + + // init progress + private Float initProgress; + + // error + private String error; + + // track last job run time, will clean up cache if no access after 2 intervals + private long lastJobRunTime; + + // detector interval in milliseconds. + private long detectorIntervalInMillis; + + // we query result index to check if there are any result generated for detector to tell whether it passed initialization of not. + // To avoid repeated query when there is no data, record whether we have done that or not. 
+ private boolean queriedResultIndex; + + public ADRealtimeTaskCache(String state, Float initProgress, String error, long detectorIntervalInMillis) { + this.state = state; + this.initProgress = initProgress; + this.error = error; + this.lastJobRunTime = Instant.now().toEpochMilli(); + this.detectorIntervalInMillis = detectorIntervalInMillis; + this.queriedResultIndex = false; + } + + public String getState() { + return state; + } + + public void setState(String state) { + this.state = state; + } + + public Float getInitProgress() { + return initProgress; + } + + public void setInitProgress(Float initProgress) { + this.initProgress = initProgress; + } + + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } + + public void setLastJobRunTime(long lastJobRunTime) { + this.lastJobRunTime = lastJobRunTime; + } + + public boolean hasQueriedResultIndex() { + return queriedResultIndex; + } + + public void setQueriedResultIndex(boolean queriedResultIndex) { + this.queriedResultIndex = queriedResultIndex; + } + + public boolean expired() { + return lastJobRunTime + 2 * detectorIntervalInMillis < Instant.now().toEpochMilli(); + } +} diff --git a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java-e b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java-e new file mode 100644 index 000000000..0df994963 --- /dev/null +++ b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java-e @@ -0,0 +1,1393 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.task; + +import static org.opensearch.ad.MemoryTracker.Origin.HISTORICAL_SINGLE_ENTITY_DETECTOR; +import static org.opensearch.ad.constant.ADCommonMessages.DETECTOR_IS_RUNNING; +import static org.opensearch.ad.constant.ADCommonMessages.EXCEED_HISTORICAL_ANALYSIS_LIMIT; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_CACHED_DELETED_TASKS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.NUM_TREES; +import static org.opensearch.timeseries.util.ParseUtils.isNullOrEmpty; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Deque; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Queue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Semaphore; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskState; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.timeseries.common.exception.DuplicateTaskException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; +import com.google.common.collect.ImmutableList; + +public class ADTaskCacheManager { + private final Logger logger = LogManager.getLogger(ADTaskCacheManager.class); + + private volatile Integer maxAdBatchTaskPerNode; + private volatile Integer maxCachedDeletedTask; + private final MemoryTracker memoryTracker; + private final int numberSize = 8; + public static final int TASK_RETRY_LIMIT = 3; + private final Semaphore cleanExpiredHCBatchTaskRunStatesSemaphore; + + // =================================================================== + // Fields below are caches on coordinating node + // =================================================================== + /** + * This field is to cache all detector level tasks which running on the + * coordinating node to resolve race condition. Will check if detector id + * exists in cache or not first. If user starts multiple tasks for the same + * detector, we will put the first task in cache and reject following tasks. + *
<p>Node: coordinating node</p>
+ * <p>Key: detector id; Value: detector level task id</p>
+ */ + private Map detectorTasks; + /** + * This field is to cache all HC detector level data on coordinating node, like + * pending/running entities, check more details in comments of {@link ADHCBatchTaskCache}. + *
<p>Node: coordinating node</p>
+ * <p>Key: detector id</p>
+ */ + private Map hcBatchTaskCaches; + /** + * This field is to cache all detectors' task slot and task lane limit on coordinating + * node. + *
<p>Node: coordinating node</p>
+ * <p>Key: detector id</p>
+ */ + private Map detectorTaskSlotLimit; + /** + * This field is to cache all realtime tasks on coordinating node. + *
<p>Node: coordinating node</p>
+ * <p>Key is detector id</p>
+ */ + private Map realtimeTaskCaches; + /** + * This field is to cache all deleted detector level tasks on coordinating node. + * Will try to clean up child task and AD result later. + *
<p>Node: coordinating node</p>
+ * Check {@link ADTaskManager#cleanChildTasksAndADResultsOfDeletedTask()} + */ + private Queue deletedDetectorTasks; + + // =================================================================== + // Fields below are caches on worker node + // =================================================================== + /** + * This field is to cache all batch tasks running on worker node. Both single + * entity detector task and HC entity task will be cached in this field. + *
<p>Node: worker node</p>
+ * <p>Key: task id</p>
+ */ + private final Map batchTaskCaches; + + // =================================================================== + // Fields below are caches on both coordinating and worker node + // =================================================================== + /** + * This field is to cache HC detector batch task running state on worker node. + * For example, is detector historical analysis cancelled or not, HC detector + * level task state. + *
<p>Node: worker node</p>
+ * <p>Outer Key: detector Id; Inner Key: detector level task id</p>
+ */ + private Map> hcBatchTaskRunState; + + // =================================================================== + // Fields below are caches on any data node serves delete detector + // request. Check ADTaskManager#deleteADResultOfDetector + // =================================================================== + /** + * This field is to cache deleted detector IDs. Hourly cron will poll this queue + * and clean AD results. Check {@link ADTaskManager#cleanADResultOfDeletedDetector()} + *
<p>Node: any data node serves delete detector request</p>
+ */ + private Queue deletedDetectors; + + /** + * Constructor to create AD task cache manager. + * + * @param settings ES settings + * @param clusterService ES cluster service + * @param memoryTracker AD memory tracker + */ + public ADTaskCacheManager(Settings settings, ClusterService clusterService, MemoryTracker memoryTracker) { + this.maxAdBatchTaskPerNode = MAX_BATCH_TASK_PER_NODE.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BATCH_TASK_PER_NODE, it -> maxAdBatchTaskPerNode = it); + this.maxCachedDeletedTask = MAX_CACHED_DELETED_TASKS.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_CACHED_DELETED_TASKS, it -> maxCachedDeletedTask = it); + this.batchTaskCaches = new ConcurrentHashMap<>(); + this.memoryTracker = memoryTracker; + this.detectorTasks = new ConcurrentHashMap<>(); + this.hcBatchTaskCaches = new ConcurrentHashMap<>(); + this.realtimeTaskCaches = new ConcurrentHashMap<>(); + this.deletedDetectorTasks = new ConcurrentLinkedQueue<>(); + this.deletedDetectors = new ConcurrentLinkedQueue<>(); + this.detectorTaskSlotLimit = new ConcurrentHashMap<>(); + this.hcBatchTaskRunState = new ConcurrentHashMap<>(); + this.cleanExpiredHCBatchTaskRunStatesSemaphore = new Semaphore(1); + } + + /** + * Put AD task into cache. + * If AD task is already in cache, will throw {@link IllegalArgumentException} + * If there is one AD task in cache for detector, will throw {@link IllegalArgumentException} + * If there is not enough memory for this AD task, will throw {@link LimitExceededException} + * + * @param adTask AD task + */ + public synchronized void add(ADTask adTask) { + String taskId = adTask.getTaskId(); + String detectorId = adTask.getId(); + if (contains(taskId)) { + throw new DuplicateTaskException(DETECTOR_IS_RUNNING); + } + // It's possible that multiple entity tasks of one detector run on same data node. + if (!adTask.isEntityTask() && containsTaskOfDetector(detectorId)) { + throw new DuplicateTaskException(DETECTOR_IS_RUNNING); + } + checkRunningTaskLimit(); + long neededCacheSize = calculateADTaskCacheSize(adTask); + if (!memoryTracker.canAllocateReserved(neededCacheSize)) { + throw new LimitExceededException("Not enough memory to run detector"); + } + memoryTracker.consumeMemory(neededCacheSize, true, HISTORICAL_SINGLE_ENTITY_DETECTOR); + ADBatchTaskCache taskCache = new ADBatchTaskCache(adTask); + taskCache.getCacheMemorySize().set(neededCacheSize); + batchTaskCaches.put(taskId, taskCache); + if (adTask.isEntityTask()) { + ADHCBatchTaskRunState hcBatchTaskRunState = getHCBatchTaskRunState(detectorId, adTask.getDetectorLevelTaskId()); + if (hcBatchTaskRunState != null) { + hcBatchTaskRunState.setLastTaskRunTimeInMillis(Instant.now().toEpochMilli()); + } + } + // clean expired HC batch task run states when new task starts on worker node. + cleanExpiredHCBatchTaskRunStates(); + } + + /** + * Put detector id in running detector cache. 
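+ * (An ADHCBatchTaskCache entry is only created here for HISTORICAL_HC_DETECTOR tasks;
+ * single-entity historical tasks just record the detector id to task id mapping.)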
+ * + * @param detectorId detector id + * @param adTask AD task + * @throws DuplicateTaskException throw DuplicateTaskException when the detector id already in cache + */ + public synchronized void add(String detectorId, ADTask adTask) { + if (detectorTasks.containsKey(detectorId)) { + logger.warn("detector is already in running detector cache, detectorId: " + detectorId); + throw new DuplicateTaskException(DETECTOR_IS_RUNNING); + } + logger.info("add detector in running detector cache, detectorId: {}, taskId: {}", detectorId, adTask.getTaskId()); + this.detectorTasks.put(detectorId, adTask.getTaskId()); + if (ADTaskType.HISTORICAL_HC_DETECTOR.name().equals(adTask.getTaskType())) { + ADHCBatchTaskCache adhcBatchTaskCache = new ADHCBatchTaskCache(); + this.hcBatchTaskCaches.put(detectorId, adhcBatchTaskCache); + } + // If new historical analysis starts, clean its old batch task run state directly. + hcBatchTaskRunState.remove(detectorId); + } + + /** + * check if current running batch task on current node exceeds + * max running task limitation. + * If executing task count exceeds limitation, will throw + * {@link LimitExceededException} + */ + public void checkRunningTaskLimit() { + if (size() >= maxAdBatchTaskPerNode) { + String error = EXCEED_HISTORICAL_ANALYSIS_LIMIT + ": " + maxAdBatchTaskPerNode; + throw new LimitExceededException(error); + } + } + + /** + * Get task RCF model. + * If task doesn't exist in cache, will throw {@link java.lang.IllegalArgumentException}. + * + * @param taskId AD task id + * @return RCF model + */ + public ThresholdedRandomCutForest getTRcfModel(String taskId) { + return getBatchTaskCache(taskId).getTRcfModel(); + } + + /** + * Get threshhold model training data size in bytes. + * + * @param taskId task id + * @return training data size in bytes + */ + public int getThresholdModelTrainingDataSize(String taskId) { + return getBatchTaskCache(taskId).getThresholdModelTrainingDataSize().get(); + } + + /** + * Threshold model trained or not. + * If task doesn't exist in cache, will throw {@link java.lang.IllegalArgumentException}. + * + * @param taskId AD task id + * @return true if threshold model trained; otherwise, return false + */ + public boolean isThresholdModelTrained(String taskId) { + return getBatchTaskCache(taskId).isThresholdModelTrained(); + } + + /** + * Set threshold model trained or not. + * + * @param taskId task id + * @param trained threshold model trained or not + */ + protected void setThresholdModelTrained(String taskId, boolean trained) { + ADBatchTaskCache taskCache = getBatchTaskCache(taskId); + taskCache.setThresholdModelTrained(trained); + } + + /** + * Get shingle data. + * + * @param taskId AD task id + * @return shingle data + */ + public Deque>> getShingle(String taskId) { + return getBatchTaskCache(taskId).getShingle(); + } + + /** + * Check if task exists in cache. + * + * @param taskId task id + * @return true if task exists in cache; otherwise, return false. + */ + public boolean contains(String taskId) { + return batchTaskCaches.containsKey(taskId); + } + + /** + * Check if there is task in cache for detector. + * + * @param detectorId detector id + * @return true if there is task in cache; otherwise return false + */ + public boolean containsTaskOfDetector(String detectorId) { + return batchTaskCaches.values().stream().filter(v -> Objects.equals(detectorId, v.getId())).findAny().isPresent(); + } + + /** + * Get task id list of detector. 
+ * + * @param detectorId detector id + * @return list of task id + */ + public List getTasksOfDetector(String detectorId) { + return batchTaskCaches + .values() + .stream() + .filter(v -> Objects.equals(detectorId, v.getId())) + .map(c -> c.getTaskId()) + .collect(Collectors.toList()); + } + + /** + * Get batch task cache. If task doesn't exist in cache, will throw + * {@link java.lang.IllegalArgumentException} + * We throw exception rather than return {@code Optional.empty} or null + * here, so don't need to check task existence by writing duplicate null + * checking code. All AD task exceptions will be handled in AD task manager. + * + * @param taskId task id + * @return AD batch task cache + */ + private ADBatchTaskCache getBatchTaskCache(String taskId) { + if (!contains(taskId)) { + throw new IllegalArgumentException("AD task not in cache"); + } + return batchTaskCaches.get(taskId); + } + + private List getBatchTaskCacheByDetectorId(String detectorId) { + return batchTaskCaches.values().stream().filter(v -> Objects.equals(detectorId, v.getId())).collect(Collectors.toList()); + } + + /** + * Calculate AD task cache memory usage. + * + * @param adTask AD task + * @return how many bytes will consume + */ + private long calculateADTaskCacheSize(ADTask adTask) { + AnomalyDetector detector = adTask.getDetector(); + int dimension = detector.getEnabledFeatureIds().size() * detector.getShingleSize(); + return memoryTracker + .estimateTRCFModelSize( + dimension, + NUM_TREES, + AnomalyDetectorSettings.BATCH_BOUNDING_BOX_CACHE_RATIO, + detector.getShingleSize().intValue(), + false + ) + shingleMemorySize(detector.getShingleSize(), detector.getEnabledFeatureIds().size()); + } + + /** + * Get RCF model size in bytes. + * + * @param taskId task id + * @return model size in bytes + */ + public long getModelSize(String taskId) { + ADBatchTaskCache batchTaskCache = getBatchTaskCache(taskId); + ThresholdedRandomCutForest tRCF = batchTaskCache.getTRcfModel(); + RandomCutForest rcfForest = tRCF.getForest(); + int dimensions = rcfForest.getDimensions(); + int numberOfTrees = rcfForest.getNumberOfTrees(); + return memoryTracker + .estimateTRCFModelSize(dimensions, numberOfTrees, AnomalyDetectorSettings.BATCH_BOUNDING_BOX_CACHE_RATIO, 1, false); + } + + /** + * Remove task from cache and refresh last run time of HC batch task run state. + * Don't remove all detector cache here as it's possible that some entity task running on other worker nodes + * + * @param taskId AD task id + * @param detectorId detector id + * @param detectorTaskId detector level task id + */ + public void remove(String taskId, String detectorId, String detectorTaskId) { + ADBatchTaskCache taskCache = batchTaskCaches.get(taskId); + if (taskCache != null) { + logger.debug("Remove batch task from cache, task id: {}", taskId); + memoryTracker.releaseMemory(taskCache.getCacheMemorySize().get(), true, HISTORICAL_SINGLE_ENTITY_DETECTOR); + batchTaskCaches.remove(taskId); + ADHCBatchTaskRunState hcBatchTaskRunState = getHCBatchTaskRunState(detectorId, detectorTaskId); + if (hcBatchTaskRunState != null) { + hcBatchTaskRunState.setLastTaskRunTimeInMillis(Instant.now().toEpochMilli()); + } + } + } + + /** + * Only remove detector cache if no running entities. 
+ * + * @param detectorId detector id + */ + public void removeHistoricalTaskCacheIfNoRunningEntity(String detectorId) { + ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId); + if (taskCache != null) { + if (taskCache.hasRunningEntity()) { + throw new IllegalArgumentException("HC detector still has running entities"); + } + } + removeHistoricalTaskCache(detectorId); + } + + /** + * Remove detector id from running detector cache + * + * @param detectorId detector id + */ + public void removeHistoricalTaskCache(String detectorId) { + ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId); + if (taskCache != null) { + // this will happen only on coordinating node. When worker nodes left, + // we will reset task state as STOPPED and clean up cache, add this warning + // to make it easier to debug issue. + if (hasEntity(detectorId)) { + logger + .warn( + "There are still entities for detector. pending: {}, running: {}, temp: {}", + Arrays.toString(taskCache.getPendingEntities()), + Arrays.toString(taskCache.getRunningEntities()), + Arrays.toString(taskCache.getTempEntities()) + ); + } + taskCache.clear(); + hcBatchTaskCaches.remove(detectorId); + } + List tasksOfDetector = getTasksOfDetector(detectorId); + for (String taskId : tasksOfDetector) { + remove(taskId, null, null); + } + if (tasksOfDetector.size() > 0) { + logger + .warn( + "Removed historical AD task from cache for detector {}, taskId: {}", + detectorId, + Arrays.toString(tasksOfDetector.toArray(new String[0])) + ); + } + if (detectorTasks.containsKey(detectorId)) { + detectorTasks.remove(detectorId); + logger.info("Removed detector from AD task coordinating node cache, detectorId: " + detectorId); + } + detectorTaskSlotLimit.remove(detectorId); + hcBatchTaskRunState.remove(detectorId); + } + + /** + * Cancel AD task by detector id. + * + * @param detectorId detector id + * @param detectorTaskId detector level task id + * @param reason why need to cancel task + * @param userName user name + * @return AD task cancellation state + */ + public ADTaskCancellationState cancelByDetectorId(String detectorId, String detectorTaskId, String reason, String userName) { + if (detectorId == null || detectorTaskId == null) { + throw new IllegalArgumentException("Can't cancel task for null detector id or detector task id"); + } + ADHCBatchTaskCache hcTaskCache = hcBatchTaskCaches.get(detectorId); + List taskCaches = getBatchTaskCacheByDetectorId(detectorId); + if (hcTaskCache != null) { + // coordinating node + logger.debug("Set HC historical analysis as cancelled for detector {}", detectorId); + hcTaskCache.clearPendingEntities(); + hcTaskCache.setEntityTaskLanes(0); + } + ADHCBatchTaskRunState taskStateCache = getOrCreateHCDetectorTaskStateCache(detectorId, detectorTaskId); + taskStateCache.setCancelledTimeInMillis(Instant.now().toEpochMilli()); + taskStateCache.setHistoricalAnalysisCancelled(true); + taskStateCache.setCancelReason(reason); + taskStateCache.setCancelledBy(userName); + + if (isNullOrEmpty(taskCaches)) { + return ADTaskCancellationState.NOT_FOUND; + } + + ADTaskCancellationState cancellationState = ADTaskCancellationState.ALREADY_CANCELLED; + for (ADBatchTaskCache cache : taskCaches) { + if (!cache.isCancelled()) { + cancellationState = ADTaskCancellationState.CANCELLED; + cache.cancel(reason, userName); + } + } + return cancellationState; + } + + /** + * Check if single entity detector level task or HC entity task is cancelled or not. 
+ * + * @param taskId AD task id, should not be HC detector level task + * @return true if task is cancelled; otherwise return false + */ + public boolean isCancelled(String taskId) { + // For HC detector, ADBatchTaskCache is entity task. + ADBatchTaskCache taskCache = getBatchTaskCache(taskId); + String detectorId = taskCache.getId(); + String detectorTaskId = taskCache.getDetectorTaskId(); + + ADHCBatchTaskRunState taskStateCache = getHCBatchTaskRunState(detectorId, detectorTaskId); + boolean hcDetectorStopped = false; + if (taskStateCache != null) { + hcDetectorStopped = taskStateCache.getHistoricalAnalysisCancelled(); + } + // If a new entity task comes after cancel event, then we have no chance to set it as cancelled. + // So we need to check hcDetectorStopped for HC detector to know if it's cancelled or not. + // For single entity detector, it has just 1 task, just need to check taskCache.isCancelled. + return taskCache.isCancelled() || hcDetectorStopped; + } + + /** + * Get current task count in cache. + * + * @return task count + */ + public int size() { + return batchTaskCaches.size(); + } + + /** + * Clear all tasks. + */ + public void clear() { + batchTaskCaches.clear(); + detectorTasks.clear(); + } + + /** + * Estimate max memory usage of model training data. + * The training data is double and will cache in double array. + * One double consumes 8 bytes. + * + * @param size training data point count + * @return how many bytes will consume + */ + public long trainingDataMemorySize(int size) { + return numberSize * size; + } + + /** + * Estimate max memory usage of shingle data. + * One feature aggregated data point(double) consumes 8 bytes. + * The shingle data is stored in {@link java.util.Deque}. From testing, + * other parts except feature data consume 80 bytes. + * + * Check {@link ADBatchTaskCache#getShingle()} + * + * @param shingleSize shingle data point count + * @param enabledFeatureSize enabled feature count + * @return how many bytes will consume + */ + public long shingleMemorySize(int shingleSize, int enabledFeatureSize) { + return (80 + numberSize * enabledFeatureSize) * shingleSize; + } + + /** + * HC top entity initied or not + * + * @param detectorId detector id + * @return true if top entity inited; otherwise return false + */ + public synchronized boolean topEntityInited(String detectorId) { + return hcBatchTaskCaches.containsKey(detectorId) ? hcBatchTaskCaches.get(detectorId).getTopEntitiesInited() : false; + } + + /** + * Set top entity inited as true. + * + * @param detectorId detector id + */ + public void setTopEntityInited(String detectorId) { + getExistingHCTaskCache(detectorId).setTopEntitiesInited(true); + } + + /** + * Get pending to run entity count. + * + * @param detectorId detector id + * @return entity count + */ + public int getPendingEntityCount(String detectorId) { + return hcBatchTaskCaches.containsKey(detectorId) ? hcBatchTaskCaches.get(detectorId).getPendingEntityCount() : 0; + } + + /** + * Get current running entity count in cache of detector. 
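// Illustrative sketch, not part of this patch: the two memory formulas above
// (trainingDataMemorySize and shingleMemorySize) written as a standalone helper so the
// arithmetic is easy to check. The 8-byte constant mirrors the class's numberSize field
// (one double per data point); the 80-byte per-shingle overhead is the figure quoted in the
// shingleMemorySize javadoc. Class and method names here are hypothetical.
final class BatchTaskMemoryFormulas {
    private static final long BYTES_PER_DOUBLE = 8L;

    static long trainingDataBytes(int trainingPoints) {
        return BYTES_PER_DOUBLE * trainingPoints;
    }

    static long shingleBytes(int shingleSize, int enabledFeatures) {
        return (80L + BYTES_PER_DOUBLE * enabledFeatures) * shingleSize;
    }

    public static void main(String[] args) {
        // 1024 training points -> 8192 bytes
        System.out.println(trainingDataBytes(1024));
        // shingle of 8 points with 2 features -> (80 + 16) * 8 = 768 bytes
        System.out.println(shingleBytes(8, 2));
    }
}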
+ * + * @param detectorId detector id + * @return count of detector's running entity in cache + */ + public int getRunningEntityCount(String detectorId) { + ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId); + if (taskCache != null) { + return taskCache.getRunningEntityCount(); + } + return 0; + } + + public int getTempEntityCount(String detectorId) { + ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId); + if (taskCache != null) { + return taskCache.getTempEntityCount(); + } + return 0; + } + + /** + * Get total top entity count for detector. + * + * @param detectorId detector id + * @return total top entity count + */ + public Integer getTopEntityCount(String detectorId) { + ADHCBatchTaskCache batchTaskCache = hcBatchTaskCaches.get(detectorId); + if (batchTaskCache != null) { + return batchTaskCache.getTopEntityCount(); + } else { + return 0; + } + } + + /** + * Get current running entities of detector. + * Profile API will call this method. + * + * @param detectorId detector id + * @return detector's running entities in cache + */ + public List getRunningEntities(String detectorId) { + if (hcBatchTaskCaches.containsKey(detectorId)) { + ADHCBatchTaskCache hcTaskCache = getExistingHCTaskCache(detectorId); + return Arrays.asList(hcTaskCache.getRunningEntities()); + } + return null; + } + + /** + * Set max allowed running entities for HC detector. + * + * @param detectorId detector id + * @param allowedRunningEntities max allowed running entities + */ + public void setAllowedRunningEntities(String detectorId, int allowedRunningEntities) { + logger.debug("Set allowed running entities of detector {} as {}", detectorId, allowedRunningEntities); + getExistingHCTaskCache(detectorId).setEntityTaskLanes(allowedRunningEntities); + } + + /** + * Set detector task slots. We cache task slots assigned to detector on coordinating node. + * When start new historical analysis, will gather detector task slots on all nodes and + * check how many task slots available for new historical analysis. + * + * @param detectorId detector id + * @param taskSlots task slots + */ + public synchronized void setDetectorTaskSlots(String detectorId, int taskSlots) { + logger.debug("Set task slots of detector {} as {}", detectorId, taskSlots); + ADTaskSlotLimit adTaskSlotLimit = detectorTaskSlotLimit + .computeIfAbsent(detectorId, key -> new ADTaskSlotLimit(taskSlots, taskSlots)); + adTaskSlotLimit.setDetectorTaskSlots(taskSlots); + } + + /** + * Scale up detector task slots. + * @param detectorId detector id + * @param delta scale delta + */ + public synchronized void scaleUpDetectorTaskSlots(String detectorId, int delta) { + ADTaskSlotLimit adTaskSlotLimit = detectorTaskSlotLimit.get(detectorId); + int taskSlots = this.getDetectorTaskSlots(detectorId); + if (adTaskSlotLimit != null && delta > 0) { + int newTaskSlots = adTaskSlotLimit.getDetectorTaskSlots() + delta; + logger.info("Scale up task slots of detector {} from {} to {}", detectorId, taskSlots, newTaskSlots); + adTaskSlotLimit.setDetectorTaskSlots(newTaskSlots); + } + } + + /** + * Check how many unfinished entities in cache. If it's less than detector task slots, we + * can scale down detector task slots to same as unfinished entities count. We can save + * task slots in this way. The released task slots can be reused for other task run. 
+ * @param detectorId detector id + * @param delta scale delta + * @return new task slots + */ + public synchronized int scaleDownHCDetectorTaskSlots(String detectorId, int delta) { + ADTaskSlotLimit adTaskSlotLimit = this.detectorTaskSlotLimit.get(detectorId); + int taskSlots = this.getDetectorTaskSlots(detectorId); + if (adTaskSlotLimit != null && delta > 0) { + int newTaskSlots = taskSlots - delta; + if (newTaskSlots > 0) { + logger.info("Scale down task slots of detector {} from {} to {}", detectorId, taskSlots, newTaskSlots); + adTaskSlotLimit.setDetectorTaskSlots(newTaskSlots); + return newTaskSlots; + } + } + return taskSlots; + } + + /** + * Set detector task lane limit. + * @param detectorId detector id + * @param taskLaneLimit task lane limit + */ + public synchronized void setDetectorTaskLaneLimit(String detectorId, int taskLaneLimit) { + ADTaskSlotLimit adTaskSlotLimit = detectorTaskSlotLimit.get(detectorId); + if (adTaskSlotLimit != null) { + adTaskSlotLimit.setDetectorTaskLaneLimit(taskLaneLimit); + } + } + + /** + * Get how many task slots assigned to detector + * @param detectorId detector id + * @return detector task slot count + */ + public int getDetectorTaskSlots(String detectorId) { + ADTaskSlotLimit taskSlotLimit = detectorTaskSlotLimit.get(detectorId); + if (taskSlotLimit != null) { + return taskSlotLimit.getDetectorTaskSlots(); + } + return 0; + } + + public int getUnfinishedEntityCount(String detectorId) { + ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId); + if (taskCache != null) { + return taskCache.getUnfinishedEntityCount(); + } + return 0; + } + + /** + * Get total task slots on this node. + * @return total task slots + */ + public int getTotalDetectorTaskSlots() { + int totalTaskSLots = 0; + for (Map.Entry entry : detectorTaskSlotLimit.entrySet()) { + totalTaskSLots += entry.getValue().getDetectorTaskSlots(); + } + return totalTaskSLots; + } + + public int getTotalBatchTaskCount() { + return batchTaskCaches.size(); + } + + /** + * Get current allowed entity task lanes and decrease it by 1. + * + * @param detectorId detector id + * @return current allowed entity task lane count + */ + public synchronized int getAndDecreaseEntityTaskLanes(String detectorId) { + return getExistingHCTaskCache(detectorId).getAndDecreaseEntityTaskLanes(); + } + + /** + * Get current available new entity task lanes. + * @param detectorId detector id + * @return how many task lane available now + */ + public int getAvailableNewEntityTaskLanes(String detectorId) { + return getExistingHCTaskCache(detectorId).getEntityTaskLanes(); + } + + private ADHCBatchTaskCache getExistingHCTaskCache(String detectorId) { + ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId); + if (taskCache != null) { + return taskCache; + } else { + throw new IllegalArgumentException("Can't find HC detector in cache"); + } + } + + /** + * Add list of entities into pending entities queue. And will remove these entities + * from temp entities queue. + * + * @param detectorId detector id + * @param entities list of entities + */ + public void addPendingEntities(String detectorId, List entities) { + getExistingHCTaskCache(detectorId).addPendingEntities(entities); + } + + /** + * Check if there is any HC task running on current node. 
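// Illustrative sketch, not part of this patch: a minimal version of the per-detector task
// slot bookkeeping above (setDetectorTaskSlots / scaleUpDetectorTaskSlots /
// scaleDownHCDetectorTaskSlots), using a plain map of ints instead of the plugin's
// ADTaskSlotLimit holder. Names are hypothetical.
import java.util.concurrent.ConcurrentHashMap;

class SlotLimitSketch {
    private final ConcurrentHashMap<String, Integer> slots = new ConcurrentHashMap<>();

    // create-if-absent then overwrite, matching computeIfAbsent + setDetectorTaskSlots above
    synchronized void setSlots(String detectorId, int taskSlots) {
        slots.put(detectorId, taskSlots);
    }

    synchronized void scaleUp(String detectorId, int delta) {
        Integer current = slots.get(detectorId);
        if (current != null && delta > 0) {
            slots.put(detectorId, current + delta);
        }
    }

    // Only shrink while the result stays positive; otherwise keep the old value, which is
    // what scaleDownHCDetectorTaskSlots does before returning the unchanged slot count.
    synchronized int scaleDown(String detectorId, int delta) {
        Integer current = slots.get(detectorId);
        if (current != null && delta > 0 && current - delta > 0) {
            slots.put(detectorId, current - delta);
            return current - delta;
        }
        return current == null ? 0 : current;
    }
}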
+ * @param detectorId detector id + * @return true if find detector id in any entity task or HC cache + */ + public boolean isHCTaskRunning(String detectorId) { + if (isHCTaskCoordinatingNode(detectorId)) { + return true; + } + // Only running tasks will be in cache. + Optional entityTask = this.batchTaskCaches + .values() + .stream() + .filter(cache -> Objects.equals(detectorId, cache.getId()) && cache.getEntity() != null) + .findFirst(); + return entityTask.isPresent(); + } + + /** + * Check if current node is coordianting node of HC detector. + * @param detectorId detector id + * @return true if find detector id in HC cache + */ + public boolean isHCTaskCoordinatingNode(String detectorId) { + return hcBatchTaskCaches.containsKey(detectorId); + } + + /** + * Set top entity count. + * + * @param detectorId detector id + * @param count top entity count + */ + public void setTopEntityCount(String detectorId, Integer count) { + ADHCBatchTaskCache hcTaskCache = getExistingHCTaskCache(detectorId); + hcTaskCache.setTopEntityCount(count); + ADTaskSlotLimit adTaskSlotLimit = detectorTaskSlotLimit.get(detectorId); + + if (count != null && adTaskSlotLimit != null) { + Integer detectorTaskSlots = adTaskSlotLimit.getDetectorTaskSlots(); + if (detectorTaskSlots != null && detectorTaskSlots > count) { + logger.debug("Scale down task slots from {} to the same as top entity count {}", detectorTaskSlots, count); + adTaskSlotLimit.setDetectorTaskSlots(count); + } + } + } + + /** + * Poll one entity from HC detector entities cache. If entity exists, will move + * entity to temp entites cache; otherwise return null. + * + * @param detectorId detector id + * @return one entity + */ + public synchronized String pollEntity(String detectorId) { + if (this.hcBatchTaskCaches.containsKey(detectorId)) { + ADHCBatchTaskCache hcTaskCache = this.hcBatchTaskCaches.get(detectorId); + String entity = hcTaskCache.pollEntity(); + return entity; + } else { + return null; + } + } + + /** + * Add entity into pending entities queue. And will remove the entity from temp + * and running entities queue. + * + * @param detectorId detector id + * @param entity entity value + */ + public void addPendingEntity(String detectorId, String entity) { + addPendingEntities(detectorId, ImmutableList.of(entity)); + } + + /** + * Move one entity to running entity queue. + * + * @param detectorId detector id + * @param entity entity value + */ + public void moveToRunningEntity(String detectorId, String entity) { + ADHCBatchTaskCache hcTaskCache = hcBatchTaskCaches.get(detectorId); + if (hcTaskCache != null) { + hcTaskCache.moveToRunningEntity(entity); + } + } + + /** + * Task exceeds max retry limit or not. + * + * @param detectorId detector id + * @param taskId task id + * @return true if exceed retry limit; otherwise return false + */ + public boolean exceedRetryLimit(String detectorId, String taskId) { + return getExistingHCTaskCache(detectorId).getTaskRetryTimes(taskId) > TASK_RETRY_LIMIT; + } + + /** + * Push entity back to the end of pending entity queue. + * + * @param taskId task id + * @param detectorId detector id + * @param entity entity value + */ + public void pushBackEntity(String taskId, String detectorId, String entity) { + addPendingEntity(detectorId, entity); + increaseEntityTaskRetry(detectorId, taskId); + } + + /** + * Increase entity task retry times. 
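// Illustrative sketch, not part of this patch: the entity lifecycle that pollEntity,
// moveToRunningEntity, pushBackEntity and exceedRetryLimit above manage, reduced to three
// collections and a retry counter. EntityLanesSketch and RETRY_LIMIT are hypothetical
// stand-ins for ADHCBatchTaskCache and TASK_RETRY_LIMIT; the real cache keys retries by
// task id rather than by entity value, which is simplified here.
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;

class EntityLanesSketch {
    static final int RETRY_LIMIT = 3;

    private final Queue<String> pending = new ConcurrentLinkedQueue<>();
    private final Set<String> temp = ConcurrentHashMap.newKeySet();
    private final Set<String> running = ConcurrentHashMap.newKeySet();
    private final Map<String, Integer> retries = new ConcurrentHashMap<>();

    // pending -> temp, like pollEntity(detectorId)
    synchronized String poll() {
        String entity = pending.poll();
        if (entity != null) {
            temp.add(entity);
        }
        return entity;
    }

    // temp -> running, like moveToRunningEntity(detectorId, entity)
    synchronized void moveToRunning(String entity) {
        if (temp.remove(entity)) {
            running.add(entity);
        }
    }

    // a failed entity goes back to the end of the pending queue and its retry count grows,
    // like pushBackEntity(taskId, detectorId, entity)
    synchronized void pushBack(String entity) {
        temp.remove(entity);
        running.remove(entity);
        pending.add(entity);
        retries.merge(entity, 1, Integer::sum);
    }

    boolean exceedRetryLimit(String entity) {
        return retries.getOrDefault(entity, 0) > RETRY_LIMIT;
    }
}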
+ * + * @param detectorId detector id + * @param taskId task id + * @return how many times retried + */ + public int increaseEntityTaskRetry(String detectorId, String taskId) { + return getExistingHCTaskCache(detectorId).increaseTaskRetry(taskId); + } + + /** + * Remove entity from cache. + * + * @param detectorId detector id + * @param entity entity value + */ + public void removeEntity(String detectorId, String entity) { + if (hcBatchTaskCaches.containsKey(detectorId)) { + hcBatchTaskCaches.get(detectorId).removeEntity(entity); + } + } + + /** + * Return AD task's entity list. + * + * @param taskId AD task id + * @return entity + */ + public Entity getEntity(String taskId) { + return getBatchTaskCache(taskId).getEntity(); + } + + /** + * Check if detector still has entity in cache. + * + * @param detectorId detector id + * @return true if detector still has entity in cache + */ + public synchronized boolean hasEntity(String detectorId) { + return hcBatchTaskCaches.containsKey(detectorId) && hcBatchTaskCaches.get(detectorId).hasEntity(); + } + + /** + * Remove entity from HC task running entity cache. + * + * @param detectorId detector id + * @param entity entity + * @return true if entity was removed as a result of this call + */ + public boolean removeRunningEntity(String detectorId, String entity) { + ADHCBatchTaskCache hcTaskCache = hcBatchTaskCaches.get(detectorId); + if (hcTaskCache != null) { + boolean removed = hcTaskCache.removeRunningEntity(entity); + logger.debug("Remove entity from running entities cache: {}: {}", entity, removed); + return removed; + } + return false; + } + + /** + * Try to get semaphore to update detector task. + * + * If the timeout is less than or equal to zero, will not wait at all to get 1 permit. + * If permit is available, will acquire 1 permit and return true immediately. If no permit, + * will wait for other thread release. If no permit available until timeout elapses, will + * return false. + * + * @param detectorId detector id + * @param timeoutInMillis timeout in milliseconds to wait for a permit, zero or negative means don't wait at all + * @return true if can get semaphore + * @throws InterruptedException if the current thread is interrupted + */ + public boolean tryAcquireTaskUpdatingSemaphore(String detectorId, long timeoutInMillis) throws InterruptedException { + ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId); + if (taskCache != null) { + return taskCache.tryAcquireTaskUpdatingSemaphore(timeoutInMillis); + } + return false; + } + + /** + * Try to release semaphore of updating detector task. + * @param detectorId detector id + */ + public void releaseTaskUpdatingSemaphore(String detectorId) { + ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId); + if (taskCache != null) { + taskCache.releaseTaskUpdatingSemaphore(); + } + } + + /** + * Clear pending entities of HC detector. + * + * @param detectorId detector id + */ + public void clearPendingEntities(String detectorId) { + ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId); + if (taskCache != null) { + taskCache.clearPendingEntities(); + } + } + + /** + * Check if realtime task field value change needed or not by comparing with cache. + * 1. If new field value is null, will consider changed needed to this field. + * 2. will consider the real time task change needed if + * 1) init progress is larger or the old init progress is null, or + * 2) if the state is different, and it is not changing from running to init. 
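// Illustrative sketch, not part of this patch: the semantics described for
// tryAcquireTaskUpdatingSemaphore above map directly onto java.util.concurrent.Semaphore.
// A single permit means only one thread updates a detector-level task document at a time;
// the class name below is hypothetical.
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

class TaskUpdateLockSketch {
    private final Semaphore updatingSemaphore = new Semaphore(1);

    // zero or negative timeout: check once and return immediately; positive timeout: wait up
    // to that many milliseconds for another thread to release the permit.
    boolean tryAcquire(long timeoutInMillis) throws InterruptedException {
        return updatingSemaphore.tryAcquire(timeoutInMillis, TimeUnit.MILLISECONDS);
    }

    void release() {
        updatingSemaphore.release();
    }
}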
+ * for other fields, as long as field values changed, will consider the realtime + * task change needed. We did this so that the init progress or state won't go backwards. + * 3. If realtime task cache not found, will consider the realtime task change needed. + * + * @param detectorId detector id + * @param newState new task state + * @param newInitProgress new init progress + * @param newError new error + * @return true if realtime task change needed. + */ + public boolean isRealtimeTaskChangeNeeded(String detectorId, String newState, Float newInitProgress, String newError) { + if (realtimeTaskCaches.containsKey(detectorId)) { + ADRealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(detectorId); + boolean stateChangeNeeded = false; + String oldState = realtimeTaskCache.getState(); + if (newState != null + && !newState.equals(oldState) + && !(ADTaskState.INIT.name().equals(newState) && ADTaskState.RUNNING.name().equals(oldState))) { + stateChangeNeeded = true; + } + boolean initProgressChangeNeeded = false; + Float existingProgress = realtimeTaskCache.getInitProgress(); + if (newInitProgress != null + && !newInitProgress.equals(existingProgress) + && (existingProgress == null || newInitProgress > existingProgress)) { + initProgressChangeNeeded = true; + } + boolean errorChanged = false; + if (newError != null && !newError.equals(realtimeTaskCache.getError())) { + errorChanged = true; + } + if (stateChangeNeeded || initProgressChangeNeeded || errorChanged) { + return true; + } + return false; + } else { + return true; + } + } + + /** + * Update realtime task cache with new field values. If realtime task cache exist, update it + * directly if task is not done; if task is done, remove the detector's realtime task cache. + * + * If realtime task cache doesn't exist, will do nothing. Next realtime job run will re-init + * realtime task cache when it finds task cache not inited yet. + * Check {@link ADTaskManager#initRealtimeTaskCacheAndCleanupStaleCache(String, AnomalyDetector, TransportService, ActionListener)}, + * {@link ADTaskManager#updateLatestRealtimeTaskOnCoordinatingNode(String, String, Long, Long, String, ActionListener)} + * + * @param detectorId detector id + * @param newState new task state + * @param newInitProgress new init progress + * @param newError new error + */ + public void updateRealtimeTaskCache(String detectorId, String newState, Float newInitProgress, String newError) { + ADRealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(detectorId); + if (realtimeTaskCache != null) { + if (newState != null) { + realtimeTaskCache.setState(newState); + } + if (newInitProgress != null) { + realtimeTaskCache.setInitProgress(newInitProgress); + } + if (newError != null) { + realtimeTaskCache.setError(newError); + } + if (newState != null && !ADTaskState.NOT_ENDED_STATES.contains(newState)) { + // If task is done, will remove its realtime task cache. 
+ logger.info("Realtime task done with state {}, remove RT task cache for detector ", newState, detectorId); + removeRealtimeTaskCache(detectorId); + } + } else { + logger.debug("Realtime task cache is not inited yet for detector {}", detectorId); + } + } + + public void initRealtimeTaskCache(String detectorId, long detectorIntervalInMillis) { + realtimeTaskCaches.put(detectorId, new ADRealtimeTaskCache(null, null, null, detectorIntervalInMillis)); + logger.debug("Realtime task cache inited"); + } + + public void refreshRealtimeJobRunTime(String detectorId) { + ADRealtimeTaskCache taskCache = realtimeTaskCaches.get(detectorId); + if (taskCache != null) { + taskCache.setLastJobRunTime(Instant.now().toEpochMilli()); + } + } + + /** + * Get detector IDs from realtime task cache. + * @return array of detector id + */ + public String[] getDetectorIdsInRealtimeTaskCache() { + return realtimeTaskCaches.keySet().toArray(new String[0]); + } + + /** + * Remove detector's realtime task from cache. + * @param detectorId detector id + */ + public void removeRealtimeTaskCache(String detectorId) { + if (realtimeTaskCaches.containsKey(detectorId)) { + logger.info("Delete realtime cache for detector {}", detectorId); + realtimeTaskCaches.remove(detectorId); + } + } + + public ADRealtimeTaskCache getRealtimeTaskCache(String detectorId) { + return realtimeTaskCaches.get(detectorId); + } + + /** + * Clear realtime task cache. + */ + public void clearRealtimeTaskCache() { + realtimeTaskCaches.clear(); + } + + /** + * Add deleted task's id to deleted detector tasks queue. + * @param taskId task id + */ + public void addDeletedDetectorTask(String taskId) { + if (deletedDetectorTasks.size() < maxCachedDeletedTask) { + deletedDetectorTasks.add(taskId); + } + } + + /** + * Check if deleted task queue has items. + * @return true if has deleted detector task in cache + */ + public boolean hasDeletedDetectorTask() { + return !deletedDetectorTasks.isEmpty(); + } + + /** + * Poll one deleted detector task. + * @return task id + */ + public String pollDeletedDetectorTask() { + return this.deletedDetectorTasks.poll(); + } + + /** + * Add deleted detector's id to deleted detector queue. + * @param detectorId detector id + */ + public void addDeletedDetector(String detectorId) { + if (deletedDetectors.size() < maxCachedDeletedTask) { + deletedDetectors.add(detectorId); + } + } + + /** + * Poll one deleted detector. 
+ * @return detector id + */ + public String pollDeletedDetector() { + return this.deletedDetectors.poll(); + } + + public String getDetectorTaskId(String detectorId) { + return detectorTasks.get(detectorId); + } + + public Instant getLastScaleEntityTaskLaneTime(String detectorId) { + ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId); + if (taskCache != null) { + return taskCache.getLastScaleEntityTaskSlotsTime(); + } + return null; + } + + public void refreshLastScaleEntityTaskLaneTime(String detectorId) { + ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId); + if (taskCache != null) { + taskCache.setLastScaleEntityTaskSlotsTime(Instant.now()); + } + } + + public Instant getLatestHCTaskRunTime(String detectorId) { + ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId); + if (taskCache != null) { + return taskCache.getLatestTaskRunTime(); + } + return null; + } + + public void refreshLatestHCTaskRunTime(String detectorId) { + ADHCBatchTaskCache taskCache = hcBatchTaskCaches.get(detectorId); + if (taskCache != null) { + taskCache.refreshLatestTaskRunTime(); + } + } + + /** + * Update detector level task's state in cache. + * @param detectorId detector id + * @param detectorTaskId detector level task id + * @param newState new state + */ + public synchronized void updateDetectorTaskState(String detectorId, String detectorTaskId, String newState) { + ADHCBatchTaskRunState cache = getOrCreateHCDetectorTaskStateCache(detectorId, detectorTaskId); + if (cache != null) { + cache.setDetectorTaskState(newState); + cache.setLastTaskRunTimeInMillis(Instant.now().toEpochMilli()); + } + } + + public ADHCBatchTaskRunState getOrCreateHCDetectorTaskStateCache(String detectorId, String detectorTaskId) { + Map states = hcBatchTaskRunState.computeIfAbsent(detectorId, it -> new ConcurrentHashMap<>()); + return states.computeIfAbsent(detectorTaskId, it -> new ADHCBatchTaskRunState()); + } + + public String getDetectorTaskState(String detectorId, String detectorTaskId) { + ADHCBatchTaskRunState batchTaskRunStates = getHCBatchTaskRunState(detectorId, detectorTaskId); + if (batchTaskRunStates != null) { + return batchTaskRunStates.getDetectorTaskState(); + } + return null; + } + + public boolean detectorTaskStateExists(String detectorId, String detectorTaskId) { + Map taskStateCache = hcBatchTaskRunState.get(detectorId); + return taskStateCache != null && taskStateCache.containsKey(detectorTaskId); + } + + private ADHCBatchTaskRunState getHCBatchTaskRunState(String detectorId, String detectorTaskId) { + if (detectorId == null || detectorTaskId == null) { + return null; + } + Map batchTaskRunStates = hcBatchTaskRunState.get(detectorId); + if (batchTaskRunStates != null) { + return batchTaskRunStates.get(detectorTaskId); + } + return null; + } + + /** + * Check if HC detector's historical analysis cancelled or not. + * + * @param detectorId detector id + * @param detectorTaskId detector level task id + * @return true if HC detector historical analysis cancelled; otherwise return false + */ + public boolean isHistoricalAnalysisCancelledForHC(String detectorId, String detectorTaskId) { + ADHCBatchTaskRunState taskStateCache = getHCBatchTaskRunState(detectorId, detectorTaskId); + if (taskStateCache != null) { + return taskStateCache.getHistoricalAnalysisCancelled(); + } + return false; + } + + /** + * Get why task cancelled. 
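// Illustrative sketch, not part of this patch: the two-level run-state cache that
// getOrCreateHCDetectorTaskStateCache above builds with nested computeIfAbsent calls, plus a
// simplified version of the expiry sweep performed by cleanExpiredHCBatchTaskRunStates further
// down. RunState and TIMEOUT_MILLIS are hypothetical stand-ins for ADHCBatchTaskRunState and
// HC_BATCH_TASK_CACHE_TIMEOUT_IN_MILLIS; the real sweep additionally skips detectors that
// still have running tasks and serializes itself with a semaphore, both omitted here.
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

class RunStateCacheSketch {
    static final long TIMEOUT_MILLIS = 600_000L;

    static final class RunState {
        volatile long lastRunTimeMillis = System.currentTimeMillis();

        boolean expired() {
            return System.currentTimeMillis() - lastRunTimeMillis > TIMEOUT_MILLIS;
        }
    }

    // detector id -> (detector-level task id -> run state)
    private final Map<String, Map<String, RunState>> states = new ConcurrentHashMap<>();

    RunState getOrCreate(String detectorId, String taskId) {
        return states
            .computeIfAbsent(detectorId, k -> new ConcurrentHashMap<>())
            .computeIfAbsent(taskId, k -> new RunState());
    }

    // drop expired task states, then drop detectors whose maps became empty
    void cleanExpired() {
        states.values().forEach(taskStates -> taskStates.values().removeIf(RunState::expired));
        states.values().removeIf(Map::isEmpty);
    }
}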
+ * + * @param taskId AD task id + * @return task cancellation reason + */ + public String getCancelReason(String taskId) { + return getBatchTaskCache(taskId).getCancelReason(); + } + + /** + * Get task cancelled by which user. + * + * @param taskId AD task id + * @return user name + */ + public String getCancelledBy(String taskId) { + return getBatchTaskCache(taskId).getCancelledBy(); + } + + public String getCancelledByForHC(String detectorId, String detectorTaskId) { + ADHCBatchTaskRunState taskCache = getHCBatchTaskRunState(detectorId, detectorTaskId); + if (taskCache != null) { + return taskCache.getCancelledBy(); + } + return null; + } + + public String getCancelReasonForHC(String detectorId, String detectorTaskId) { + ADHCBatchTaskRunState taskCache = getHCBatchTaskRunState(detectorId, detectorTaskId); + if (taskCache != null) { + return taskCache.getCancelReason(); + } + return null; + } + + public void cleanExpiredHCBatchTaskRunStates() { + if (!cleanExpiredHCBatchTaskRunStatesSemaphore.tryAcquire()) { + return; + } + try { + List detectorIdOfEmptyStates = new ArrayList<>(); + for (Map.Entry> detectorRunStates : hcBatchTaskRunState.entrySet()) { + List taskIdOfExpiredStates = new ArrayList<>(); + String detectorId = detectorRunStates.getKey(); + boolean noRunningTask = isNullOrEmpty(getTasksOfDetector(detectorId)); + Map taskRunStates = detectorRunStates.getValue(); + if (taskRunStates == null) { + // If detector's task run state is null, add detector id to detectorIdOfEmptyStates and remove it from + // hcBatchTaskRunState later. + detectorIdOfEmptyStates.add(detectorId); + continue; + } + if (!noRunningTask) { + // If a detector has running task, we should not clean up task run state cache for it. + // It's possible that some entity task is on the way to worker node. So we should not + // remove detector level state if no running task found. Otherwise the task may arrive + // after run state cache deleted, then it can run on work node. We should delete cache + // if no running task and run state expired. + continue; + } + for (Map.Entry taskRunState : taskRunStates.entrySet()) { + ADHCBatchTaskRunState state = taskRunState.getValue(); + if (state != null && noRunningTask && state.expired()) { + taskIdOfExpiredStates.add(taskRunState.getKey()); + } + } + logger + .debug( + "Remove expired HC batch task run states for these tasks: {}", + Arrays.toString(taskIdOfExpiredStates.toArray(new String[0])) + ); + taskIdOfExpiredStates.forEach(id -> taskRunStates.remove(id)); + if (taskRunStates.isEmpty()) { + detectorIdOfEmptyStates.add(detectorId); + } + } + logger + .debug( + "Remove empty HC batch task run states for these detectors : {}", + Arrays.toString(detectorIdOfEmptyStates.toArray(new String[0])) + ); + detectorIdOfEmptyStates.forEach(id -> hcBatchTaskRunState.remove(id)); + } catch (Exception e) { + logger.error("Failed to clean expired HC batch task run states", e); + } finally { + cleanExpiredHCBatchTaskRunStatesSemaphore.release(); + } + } + + /** + * We query result index to check if there are any result generated for detector to tell whether it passed initialization of not. + * To avoid repeated query when there is no data, record whether we have done that or not. + * @param id detector id + */ + public void markResultIndexQueried(String id) { + ADRealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(id); + // we initialize a real time cache at the beginning of AnomalyResultTransportAction if it + // cannot be found. 
If the cache is empty, we will return early and wait it for it to be + // initialized. + if (realtimeTaskCache != null) { + realtimeTaskCache.setQueriedResultIndex(true); + } + } + + /** + * We query result index to check if there are any result generated for detector to tell whether it passed initialization of not. + * + * @param id detector id + * @return whether we have queried result index or not. + */ + public boolean hasQueriedResultIndex(String id) { + ADRealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(id); + if (realtimeTaskCache != null) { + return realtimeTaskCache.hasQueriedResultIndex(); + } + return false; + } +} diff --git a/src/main/java/org/opensearch/ad/task/ADTaskCancellationState.java-e b/src/main/java/org/opensearch/ad/task/ADTaskCancellationState.java-e new file mode 100644 index 000000000..c16ba0a2f --- /dev/null +++ b/src/main/java/org/opensearch/ad/task/ADTaskCancellationState.java-e @@ -0,0 +1,18 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.task; + +public enum ADTaskCancellationState { + NOT_FOUND, + CANCELLED, + ALREADY_CANCELLED +} diff --git a/src/main/java/org/opensearch/ad/task/ADTaskManager.java b/src/main/java/org/opensearch/ad/task/ADTaskManager.java index bed979656..c482b0ba8 100644 --- a/src/main/java/org/opensearch/ad/task/ADTaskManager.java +++ b/src/main/java/org/opensearch/ad/task/ADTaskManager.java @@ -12,7 +12,6 @@ package org.opensearch.ad.task; import static org.opensearch.action.DocWriteResponse.Result.CREATED; -import static org.opensearch.ad.AnomalyDetectorPlugin.AD_BATCH_TASK_THREAD_POOL_NAME; import static org.opensearch.ad.constant.ADCommonMessages.CAN_NOT_FIND_LATEST_TASK; import static org.opensearch.ad.constant.ADCommonMessages.DETECTOR_IS_RUNNING; import static org.opensearch.ad.constant.ADCommonMessages.EXCEED_HISTORICAL_ANALYSIS_LIMIT; @@ -51,7 +50,8 @@ import static org.opensearch.ad.stats.InternalStatNames.AD_USED_BATCH_TASK_SLOT_COUNT; import static org.opensearch.ad.util.ExceptionUtil.getErrorMessage; import static org.opensearch.ad.util.ExceptionUtil.getShardsFailure; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BATCH_TASK_THREAD_POOL_NAME; import static org.opensearch.timeseries.constant.CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED; import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; import static org.opensearch.timeseries.constant.CommonName.TASK_ID_FIELD; @@ -128,7 +128,6 @@ import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; @@ -136,6 +135,8 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; import 
org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -150,7 +151,6 @@ import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.index.reindex.UpdateByQueryAction; import org.opensearch.index.reindex.UpdateByQueryRequest; -import org.opensearch.rest.RestStatus; import org.opensearch.script.Script; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; diff --git a/src/main/java/org/opensearch/ad/task/ADTaskManager.java-e b/src/main/java/org/opensearch/ad/task/ADTaskManager.java-e new file mode 100644 index 000000000..9fd5ed23c --- /dev/null +++ b/src/main/java/org/opensearch/ad/task/ADTaskManager.java-e @@ -0,0 +1,3081 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.task; + +import static org.opensearch.action.DocWriteResponse.Result.CREATED; +import static org.opensearch.ad.constant.ADCommonMessages.CAN_NOT_FIND_LATEST_TASK; +import static org.opensearch.ad.constant.ADCommonMessages.DETECTOR_IS_RUNNING; +import static org.opensearch.ad.constant.ADCommonMessages.EXCEED_HISTORICAL_ANALYSIS_LIMIT; +import static org.opensearch.ad.constant.ADCommonMessages.HC_DETECTOR_TASK_IS_UPDATING; +import static org.opensearch.ad.constant.ADCommonMessages.NO_ELIGIBLE_NODE_TO_RUN_DETECTOR; +import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX; +import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN; +import static org.opensearch.ad.model.ADTask.COORDINATING_NODE_FIELD; +import static org.opensearch.ad.model.ADTask.DETECTOR_ID_FIELD; +import static org.opensearch.ad.model.ADTask.ERROR_FIELD; +import static org.opensearch.ad.model.ADTask.ESTIMATED_MINUTES_LEFT_FIELD; +import static org.opensearch.ad.model.ADTask.EXECUTION_END_TIME_FIELD; +import static org.opensearch.ad.model.ADTask.EXECUTION_START_TIME_FIELD; +import static org.opensearch.ad.model.ADTask.INIT_PROGRESS_FIELD; +import static org.opensearch.ad.model.ADTask.IS_LATEST_FIELD; +import static org.opensearch.ad.model.ADTask.LAST_UPDATE_TIME_FIELD; +import static org.opensearch.ad.model.ADTask.PARENT_TASK_ID_FIELD; +import static org.opensearch.ad.model.ADTask.STATE_FIELD; +import static org.opensearch.ad.model.ADTask.STOPPED_BY_FIELD; +import static org.opensearch.ad.model.ADTask.TASK_PROGRESS_FIELD; +import static org.opensearch.ad.model.ADTask.TASK_TYPE_FIELD; +import static org.opensearch.ad.model.ADTaskState.NOT_ENDED_STATES; +import static org.opensearch.ad.model.ADTaskType.ALL_HISTORICAL_TASK_TYPES; +import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; +import static org.opensearch.ad.model.ADTaskType.REALTIME_TASK_TYPES; +import static org.opensearch.ad.model.ADTaskType.taskTypeToString; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.DELETE_AD_RESULT_WHEN_DELETE_DETECTOR; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS; +import static 
org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.NUM_MIN_SAMPLES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; +import static org.opensearch.ad.stats.InternalStatNames.AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT; +import static org.opensearch.ad.stats.InternalStatNames.AD_USED_BATCH_TASK_SLOT_COUNT; +import static org.opensearch.ad.util.ExceptionUtil.getErrorMessage; +import static org.opensearch.ad.util.ExceptionUtil.getShardsFailure; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BATCH_TASK_THREAD_POOL_NAME; +import static org.opensearch.timeseries.constant.CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED; +import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; +import static org.opensearch.timeseries.constant.CommonName.TASK_ID_FIELD; +import static org.opensearch.timeseries.util.ParseUtils.isNullOrEmpty; +import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; +import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Semaphore; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionListenerResponseHandler; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADEntityTaskProfile; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskAction; +import org.opensearch.ad.model.ADTaskProfile; +import org.opensearch.ad.model.ADTaskState; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import 
org.opensearch.ad.model.DetectorProfile; +import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler; +import org.opensearch.ad.transport.ADBatchAnomalyResultAction; +import org.opensearch.ad.transport.ADBatchAnomalyResultRequest; +import org.opensearch.ad.transport.ADCancelTaskAction; +import org.opensearch.ad.transport.ADCancelTaskRequest; +import org.opensearch.ad.transport.ADStatsNodeResponse; +import org.opensearch.ad.transport.ADStatsNodesAction; +import org.opensearch.ad.transport.ADStatsRequest; +import org.opensearch.ad.transport.ADTaskProfileAction; +import org.opensearch.ad.transport.ADTaskProfileNodeResponse; +import org.opensearch.ad.transport.ADTaskProfileRequest; +import org.opensearch.ad.transport.AnomalyDetectorJobResponse; +import org.opensearch.ad.transport.ForwardADTaskAction; +import org.opensearch.ad.transport.ForwardADTaskRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.index.reindex.UpdateByQueryAction; +import org.opensearch.index.reindex.UpdateByQueryRequest; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.script.Script; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.DuplicateTaskException; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TaskCancelledException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +/** + * Manage 
AD task. + */ +public class ADTaskManager { + public static final String AD_TASK_LEAD_NODE_MODEL_ID = "ad_task_lead_node_model_id"; + public static final String AD_TASK_MAINTAINENCE_NODE_MODEL_ID = "ad_task_maintainence_node_model_id"; + // HC batch task timeout after 10 minutes if no update after last known run time. + public static final int HC_BATCH_TASK_CACHE_TIMEOUT_IN_MILLIS = 600_000; + private final Logger logger = LogManager.getLogger(this.getClass()); + static final String STATE_INDEX_NOT_EXIST_MSG = "State index does not exist."; + private final Set retryableErrors = ImmutableSet.of(EXCEED_HISTORICAL_ANALYSIS_LIMIT, NO_ELIGIBLE_NODE_TO_RUN_DETECTOR); + private final Client client; + private final ClusterService clusterService; + private final NamedXContentRegistry xContentRegistry; + private final ADIndexManagement detectionIndices; + private final DiscoveryNodeFilterer nodeFilter; + private final ADTaskCacheManager adTaskCacheManager; + + private final HashRing hashRing; + private volatile Integer maxOldAdTaskDocsPerDetector; + private volatile Integer pieceIntervalSeconds; + private volatile boolean deleteADResultWhenDeleteDetector; + private volatile TransportRequestOptions transportRequestOptions; + private final ThreadPool threadPool; + private static int DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS = 5; + private final Semaphore checkingTaskSlot; + + private volatile Integer maxAdBatchTaskPerNode; + private volatile Integer maxRunningEntitiesPerDetector; + + private final Semaphore scaleEntityTaskLane; + private static final int SCALE_ENTITY_TASK_LANE_INTERVAL_IN_MILLIS = 10_000; // 10 seconds + + public ADTaskManager( + Settings settings, + ClusterService clusterService, + Client client, + NamedXContentRegistry xContentRegistry, + ADIndexManagement detectionIndices, + DiscoveryNodeFilterer nodeFilter, + HashRing hashRing, + ADTaskCacheManager adTaskCacheManager, + ThreadPool threadPool + ) { + this.client = client; + this.xContentRegistry = xContentRegistry; + this.detectionIndices = detectionIndices; + this.nodeFilter = nodeFilter; + this.clusterService = clusterService; + this.adTaskCacheManager = adTaskCacheManager; + this.hashRing = hashRing; + + this.maxOldAdTaskDocsPerDetector = MAX_OLD_AD_TASK_DOCS_PER_DETECTOR.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, it -> maxOldAdTaskDocsPerDetector = it); + + this.pieceIntervalSeconds = BATCH_TASK_PIECE_INTERVAL_SECONDS.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(BATCH_TASK_PIECE_INTERVAL_SECONDS, it -> pieceIntervalSeconds = it); + + this.deleteADResultWhenDeleteDetector = DELETE_AD_RESULT_WHEN_DELETE_DETECTOR.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(DELETE_AD_RESULT_WHEN_DELETE_DETECTOR, it -> deleteADResultWhenDeleteDetector = it); + + this.maxAdBatchTaskPerNode = MAX_BATCH_TASK_PER_NODE.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BATCH_TASK_PER_NODE, it -> maxAdBatchTaskPerNode = it); + + this.maxRunningEntitiesPerDetector = MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS, it -> maxRunningEntitiesPerDetector = it); + + transportRequestOptions = TransportRequestOptions + .builder() + .withType(TransportRequestOptions.Type.REG) + .withTimeout(REQUEST_TIMEOUT.get(settings)) + .build(); + 
clusterService + .getClusterSettings() + .addSettingsUpdateConsumer( + REQUEST_TIMEOUT, + it -> { + transportRequestOptions = TransportRequestOptions + .builder() + .withType(TransportRequestOptions.Type.REG) + .withTimeout(it) + .build(); + } + ); + this.threadPool = threadPool; + this.checkingTaskSlot = new Semaphore(1); + this.scaleEntityTaskLane = new Semaphore(1); + } + + /** + * Start detector. Will create schedule job for realtime detector, + * and start AD task for historical detector. + * + * @param detectorId detector id + * @param detectionDateRange historical analysis date range + * @param handler anomaly detector job action handler + * @param user user + * @param transportService transport service + * @param context thread context + * @param listener action listener + */ + public void startDetector( + String detectorId, + DateRange detectionDateRange, + IndexAnomalyDetectorJobActionHandler handler, + User user, + TransportService transportService, + ThreadContext.StoredContext context, + ActionListener listener + ) { + // upgrade index mapping of AD default indices + detectionIndices.update(); + + getDetector(detectorId, (detector) -> { + if (!detector.isPresent()) { + listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, RestStatus.NOT_FOUND)); + return; + } + + // Validate if detector is ready to start. Will return null if ready to start. + String errorMessage = validateDetector(detector.get()); + if (errorMessage != null) { + listener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST)); + return; + } + String resultIndex = detector.get().getCustomResultIndex(); + if (resultIndex == null) { + startRealtimeOrHistoricalDetection(detectionDateRange, handler, user, transportService, listener, detector); + return; + } + context.restore(); + detectionIndices + .initCustomResultIndexAndExecute( + resultIndex, + () -> startRealtimeOrHistoricalDetection(detectionDateRange, handler, user, transportService, listener, detector), + listener + ); + + }, listener); + } + + private void startRealtimeOrHistoricalDetection( + DateRange detectionDateRange, + IndexAnomalyDetectorJobActionHandler handler, + User user, + TransportService transportService, + ActionListener listener, + Optional detector + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + if (detectionDateRange == null) { + // start realtime job + handler.startAnomalyDetectorJob(detector.get(), listener); + } else { + // start historical analysis task + forwardApplyForTaskSlotsRequestToLeadNode(detector.get(), detectionDateRange, user, transportService, listener); + } + } catch (Exception e) { + logger.error("Failed to stash context", e); + listener.onFailure(e); + } + } + + /** + * When AD receives start historical analysis request for a detector, will + * 1. Forward to lead node to check available task slots first. + * 2. If available task slots exit, will forward request to coordinating node + * to gather information like top entities. + * 3. Then coordinating node will choose one data node with least load as work + * node and dispatch historical analysis to it. 
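// Illustrative sketch, not part of this patch: the constructor above wires each dynamic
// cluster setting to a volatile field through addSettingsUpdateConsumer, so setting changes
// take effect without restarting the node. A dependency-free version of that pattern follows;
// SettingsRegistry and publish(...) are hypothetical stand-ins for ClusterSettings, which in
// the real plugin is driven by cluster state updates rather than a direct publish call.
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

class DynamicSettingSketch {
    static class SettingsRegistry {
        private final List<Consumer<Integer>> consumers = new ArrayList<>();

        void addSettingsUpdateConsumer(Consumer<Integer> consumer) {
            consumers.add(consumer);
        }

        void publish(int newValue) {
            consumers.forEach(c -> c.accept(newValue));
        }
    }

    private volatile int pieceIntervalSeconds;

    DynamicSettingSketch(int initialValue, SettingsRegistry registry) {
        this.pieceIntervalSeconds = initialValue;
        // same shape as clusterService.getClusterSettings().addSettingsUpdateConsumer(SETTING, it -> field = it)
        registry.addSettingsUpdateConsumer(it -> pieceIntervalSeconds = it);
    }

    int pieceIntervalSeconds() {
        return pieceIntervalSeconds;
    }
}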
+ * + * @param detector detector + * @param detectionDateRange detection date range + * @param user user + * @param transportService transport service + * @param listener action listener + */ + protected void forwardApplyForTaskSlotsRequestToLeadNode( + AnomalyDetector detector, + DateRange detectionDateRange, + User user, + TransportService transportService, + ActionListener listener + ) { + ForwardADTaskRequest forwardADTaskRequest = new ForwardADTaskRequest( + detector, + detectionDateRange, + user, + ADTaskAction.APPLY_FOR_TASK_SLOTS + ); + forwardRequestToLeadNode(forwardADTaskRequest, transportService, listener); + } + + public void forwardScaleTaskSlotRequestToLeadNode( + ADTask adTask, + TransportService transportService, + ActionListener listener + ) { + forwardRequestToLeadNode(new ForwardADTaskRequest(adTask, ADTaskAction.CHECK_AVAILABLE_TASK_SLOTS), transportService, listener); + } + + public void forwardRequestToLeadNode( + ForwardADTaskRequest forwardADTaskRequest, + TransportService transportService, + ActionListener listener + ) { + hashRing.buildAndGetOwningNodeWithSameLocalAdVersion(AD_TASK_LEAD_NODE_MODEL_ID, node -> { + if (!node.isPresent()) { + listener.onFailure(new ResourceNotFoundException("Can't find AD task lead node")); + return; + } + transportService + .sendRequest( + node.get(), + ForwardADTaskAction.NAME, + forwardADTaskRequest, + transportRequestOptions, + new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new) + ); + }, listener); + } + + /** + * Forward historical analysis task to coordinating node. + * + * @param detector anomaly detector + * @param detectionDateRange historical analysis date range + * @param user user + * @param availableTaskSlots available task slots + * @param transportService transport service + * @param listener action listener + */ + public void startHistoricalAnalysis( + AnomalyDetector detector, + DateRange detectionDateRange, + User user, + int availableTaskSlots, + TransportService transportService, + ActionListener listener + ) { + String detectorId = detector.getId(); + hashRing.buildAndGetOwningNodeWithSameLocalAdVersion(detectorId, owningNode -> { + if (!owningNode.isPresent()) { + logger.debug("Can't find eligible node to run as AD task's coordinating node"); + listener.onFailure(new OpenSearchStatusException("No eligible node to run detector", RestStatus.INTERNAL_SERVER_ERROR)); + return; + } + logger.debug("coordinating node is : {} for detector: {}", owningNode.get().getId(), detectorId); + forwardDetectRequestToCoordinatingNode( + detector, + detectionDateRange, + user, + availableTaskSlots, + ADTaskAction.START, + transportService, + owningNode.get(), + listener + ); + }, listener); + + } + + /** + * We have three types of nodes in AD task process. + * + * 1.Forwarding node which receives external request. The request will \ + * be sent to coordinating node first. + * 2.Coordinating node which maintains running historical detector set.\ + * We use hash ring to find coordinating node with detector id. \ + * Coordinating node will find a worker node with least load and \ + * dispatch AD task to that worker node. + * 3.Worker node which will run AD task. + * + * This function is to forward the request to coordinating node. 
+ * + * @param detector anomaly detector + * @param detectionDateRange historical analysis date range + * @param user user + * @param availableTaskSlots available task slots + * @param adTaskAction AD task action + * @param transportService transport service + * @param node ES node + * @param listener action listener + */ + protected void forwardDetectRequestToCoordinatingNode( + AnomalyDetector detector, + DateRange detectionDateRange, + User user, + Integer availableTaskSlots, + ADTaskAction adTaskAction, + TransportService transportService, + DiscoveryNode node, + ActionListener listener + ) { + Version adVersion = hashRing.getAdVersion(node.getId()); + transportService + .sendRequest( + node, + ForwardADTaskAction.NAME, + // We need to check AD version of remote node as we may send clean detector cache request to old + // node, check ADTaskManager#cleanDetectorCache. + new ForwardADTaskRequest(detector, detectionDateRange, user, adTaskAction, availableTaskSlots, adVersion), + transportRequestOptions, + new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new) + ); + } + + /** + * Forward AD task to coordinating node + * + * @param adTask AD task + * @param adTaskAction AD task action + * @param transportService transport service + * @param listener action listener + */ + protected void forwardADTaskToCoordinatingNode( + ADTask adTask, + ADTaskAction adTaskAction, + TransportService transportService, + ActionListener listener + ) { + logger.debug("Forward AD task to coordinating node, task id: {}, action: {}", adTask.getTaskId(), adTaskAction.name()); + transportService + .sendRequest( + getCoordinatingNode(adTask), + ForwardADTaskAction.NAME, + new ForwardADTaskRequest(adTask, adTaskAction), + transportRequestOptions, + new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new) + ); + } + + /** + * Forward stale running entities to coordinating node to clean up. + * + * @param adTask AD task + * @param adTaskAction AD task action + * @param transportService transport service + * @param staleRunningEntity stale running entities + * @param listener action listener + */ + protected void forwardStaleRunningEntitiesToCoordinatingNode( + ADTask adTask, + ADTaskAction adTaskAction, + TransportService transportService, + List staleRunningEntity, + ActionListener listener + ) { + transportService + .sendRequest( + getCoordinatingNode(adTask), + ForwardADTaskAction.NAME, + new ForwardADTaskRequest(adTask, adTaskAction, staleRunningEntity), + transportRequestOptions, + new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new) + ); + } + + /** + * Check available task slots before start historical analysis and scale task lane. + * This check will be done on lead node which will gather detector task slots of all + * data nodes and calculate how many task slots available. 
+ * + * @param adTask AD task + * @param detector detector + * @param detectionDateRange detection date range + * @param user user + * @param afterCheckAction target task action to run after task slot checking + * @param transportService transport service + * @param listener action listener + */ + public void checkTaskSlots( + ADTask adTask, + AnomalyDetector detector, + DateRange detectionDateRange, + User user, + ADTaskAction afterCheckAction, + TransportService transportService, + ActionListener listener + ) { + String detectorId = detector.getId(); + logger.debug("Start checking task slots for detector: {}, task action: {}", detectorId, afterCheckAction); + if (!checkingTaskSlot.tryAcquire()) { + logger.info("Can't acquire checking task slot semaphore for detector {}", detectorId); + listener + .onFailure( + new OpenSearchStatusException( + "Too many historical analysis requests in short time. Please retry later.", + RestStatus.FORBIDDEN + ) + ); + return; + } + ActionListener wrappedActionListener = ActionListener.runAfter(listener, () -> { + checkingTaskSlot.release(1); + logger.debug("Release checking task slot semaphore on lead node for detector {}", detectorId); + }); + hashRing.getNodesWithSameLocalAdVersion(nodes -> { + int maxAdTaskSlots = nodes.length * maxAdBatchTaskPerNode; + ADStatsRequest adStatsRequest = new ADStatsRequest(nodes); + adStatsRequest + .addAll(ImmutableSet.of(AD_USED_BATCH_TASK_SLOT_COUNT.getName(), AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT.getName())); + client.execute(ADStatsNodesAction.INSTANCE, adStatsRequest, ActionListener.wrap(adStatsResponse -> { + int totalUsedTaskSlots = 0; // Total entity tasks running on worker nodes + int totalAssignedTaskSlots = 0; // Total assigned task slots on coordinating nodes + for (ADStatsNodeResponse response : adStatsResponse.getNodes()) { + totalUsedTaskSlots += (int) response.getStatsMap().get(AD_USED_BATCH_TASK_SLOT_COUNT.getName()); + totalAssignedTaskSlots += (int) response.getStatsMap().get(AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT.getName()); + } + logger + .info( + "Current total used task slots is {}, total detector assigned task slots is {} when start historical " + + "analysis for detector {}", + totalUsedTaskSlots, + totalAssignedTaskSlots, + detectorId + ); + // In happy case, totalAssignedTaskSlots >= totalUsedTaskSlots. If some coordinating node left, then we can't + // get detector task slots cached on it, so it's possible that totalAssignedTaskSlots < totalUsedTaskSlots. + int currentUsedTaskSlots = Math.max(totalUsedTaskSlots, totalAssignedTaskSlots); + if (currentUsedTaskSlots >= maxAdTaskSlots) { + wrappedActionListener.onFailure(new OpenSearchStatusException("No available task slot", RestStatus.BAD_REQUEST)); + return; + } + int availableAdTaskSlots = maxAdTaskSlots - currentUsedTaskSlots; + logger.info("Current available task slots is {} for historical analysis of detector {}", availableAdTaskSlots, detectorId); + + if (ADTaskAction.SCALE_ENTITY_TASK_SLOTS == afterCheckAction) { + forwardToCoordinatingNode( + adTask, + detector, + detectionDateRange, + user, + afterCheckAction, + transportService, + wrappedActionListener, + availableAdTaskSlots + ); + return; + } + + // It takes long time to check top entities especially for multi-category HC. Tested with + // 1.8 billion docs for multi-category HC, it took more than 20 seconds and caused timeout. + // By removing top entity check, it took about 200ms to return. So just remove it to make + // sure REST API can return quickly. 
+ // We may assign more task slots. For example, cluster has 4 data nodes, each node can run 2 + // batch tasks, so the available task slot number is 8. If max running entities per HC is 4, + // then we will assign 4 tasks slots to this HC detector (4 is less than 8). The data index + // only has 2 entities. So we assign 2 more task slots than actual need. But it's ok as we + // will auto tune task slot when historical analysis task starts. + int approvedTaskSlots = detector.isHighCardinality() ? Math.min(maxRunningEntitiesPerDetector, availableAdTaskSlots) : 1; + forwardToCoordinatingNode( + adTask, + detector, + detectionDateRange, + user, + afterCheckAction, + transportService, + wrappedActionListener, + approvedTaskSlots + ); + }, exception -> { + logger.error("Failed to get node's task stats for detector " + detectorId, exception); + wrappedActionListener.onFailure(exception); + })); + }, wrappedActionListener); + } + + private void forwardToCoordinatingNode( + ADTask adTask, + AnomalyDetector detector, + DateRange detectionDateRange, + User user, + ADTaskAction targetActionOfTaskSlotChecking, + TransportService transportService, + ActionListener wrappedActionListener, + int approvedTaskSlots + ) { + switch (targetActionOfTaskSlotChecking) { + case START: + logger.info("Will assign {} task slots to run historical analysis for detector {}", approvedTaskSlots, detector.getId()); + startHistoricalAnalysis(detector, detectionDateRange, user, approvedTaskSlots, transportService, wrappedActionListener); + break; + case SCALE_ENTITY_TASK_SLOTS: + logger + .info( + "There are {} task slots available now to scale historical analysis task lane for detector {}", + approvedTaskSlots, + adTask.getId() + ); + scaleTaskLaneOnCoordinatingNode(adTask, approvedTaskSlots, transportService, wrappedActionListener); + break; + default: + wrappedActionListener.onFailure(new TimeSeriesException("Unknown task action " + targetActionOfTaskSlotChecking)); + break; + } + } + + protected void scaleTaskLaneOnCoordinatingNode( + ADTask adTask, + int approvedTaskSlot, + TransportService transportService, + ActionListener listener + ) { + DiscoveryNode coordinatingNode = getCoordinatingNode(adTask); + transportService + .sendRequest( + coordinatingNode, + ForwardADTaskAction.NAME, + new ForwardADTaskRequest(adTask, approvedTaskSlot, ADTaskAction.SCALE_ENTITY_TASK_SLOTS), + transportRequestOptions, + new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new) + ); + } + + private DiscoveryNode getCoordinatingNode(ADTask adTask) { + String coordinatingNode = adTask.getCoordinatingNode(); + DiscoveryNode[] eligibleDataNodes = nodeFilter.getEligibleDataNodes(); + DiscoveryNode targetNode = null; + for (DiscoveryNode node : eligibleDataNodes) { + if (node.getId().equals(coordinatingNode)) { + targetNode = node; + break; + } + } + if (targetNode == null) { + throw new ResourceNotFoundException(adTask.getId(), "AD task coordinating node not found"); + } + return targetNode; + } + + /** + * Start anomaly detector. + * For historical analysis, this method will be called on coordinating node. + * For realtime task, we won't know AD job coordinating node until AD job starts. So + * this method will be called on vanilla node. + * + * Will init task index if not exist and write new AD task to index. If task index + * exists, will check if there is task running. If no running task, reset old task + * as not latest and clean old tasks which exceeds max old task doc limitation. 
+ * Then find out node with least load and dispatch task to that node (worker node). + * + * @param detector anomaly detector + * @param detectionDateRange detection date range + * @param user user + * @param transportService transport service + * @param listener action listener + */ + public void startDetector( + AnomalyDetector detector, + DateRange detectionDateRange, + User user, + TransportService transportService, + ActionListener listener + ) { + try { + if (detectionIndices.doesStateIndexExist()) { + // If detection index exists, check if latest AD task is running + getAndExecuteOnLatestDetectorLevelTask(detector.getId(), getADTaskTypes(detectionDateRange), (adTask) -> { + if (!adTask.isPresent() || adTask.get().isDone()) { + updateLatestFlagOfOldTasksAndCreateNewTask(detector, detectionDateRange, user, listener); + } else { + listener.onFailure(new OpenSearchStatusException(DETECTOR_IS_RUNNING, RestStatus.BAD_REQUEST)); + } + }, transportService, true, listener); + } else { + // If detection index doesn't exist, create index and execute detector. + detectionIndices.initStateIndex(ActionListener.wrap(r -> { + if (r.isAcknowledged()) { + logger.info("Created {} with mappings.", DETECTION_STATE_INDEX); + updateLatestFlagOfOldTasksAndCreateNewTask(detector, detectionDateRange, user, listener); + } else { + String error = String.format(Locale.ROOT, CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); + logger.warn(error); + listener.onFailure(new OpenSearchStatusException(error, RestStatus.INTERNAL_SERVER_ERROR)); + } + }, e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) { + updateLatestFlagOfOldTasksAndCreateNewTask(detector, detectionDateRange, user, listener); + } else { + logger.error("Failed to init anomaly detection state index", e); + listener.onFailure(e); + } + })); + } + } catch (Exception e) { + logger.error("Failed to start detector " + detector.getId(), e); + listener.onFailure(e); + } + } + + private ADTaskType getADTaskType(AnomalyDetector detector, DateRange detectionDateRange) { + if (detectionDateRange == null) { + return detector.isHighCardinality() ? ADTaskType.REALTIME_HC_DETECTOR : ADTaskType.REALTIME_SINGLE_ENTITY; + } else { + return detector.isHighCardinality() ? ADTaskType.HISTORICAL_HC_DETECTOR : ADTaskType.HISTORICAL_SINGLE_ENTITY; + } + } + + private List getADTaskTypes(DateRange detectionDateRange) { + return getADTaskTypes(detectionDateRange, false); + } + + /** + * Get list of task types. + * 1. If detection date range is null, will return all realtime task types. + * 2. If detection date range is not null, will return all historical task types (including the + * HC entity level task type) if resetLatestTaskStateFlag is true; otherwise return only the + * historical detector level task types. + * @param detectionDateRange detection date range + * @param resetLatestTaskStateFlag reset latest task state or not + * @return list of AD task types + */ + private List getADTaskTypes(DateRange detectionDateRange, boolean resetLatestTaskStateFlag) { + if (detectionDateRange == null) { + return REALTIME_TASK_TYPES; + } else { + if (resetLatestTaskStateFlag) { + // return all task types including HC entity task to make sure we can reset all tasks' latest flag + return ALL_HISTORICAL_TASK_TYPES; + } else { + return HISTORICAL_DETECTOR_TASK_TYPES; + } + } + } + + /** + * Stop detector. + * For realtime detector, will set detector job as disabled. + * For historical detector, will set its AD task as cancelled.
+ * + * @param detectorId detector id + * @param historical stop historical analysis or not + * @param handler AD job action handler + * @param user user + * @param transportService transport service + * @param listener action listener + */ + public void stopDetector( + String detectorId, + boolean historical, + IndexAnomalyDetectorJobActionHandler handler, + User user, + TransportService transportService, + ActionListener listener + ) { + getDetector(detectorId, (detector) -> { + if (!detector.isPresent()) { + listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, RestStatus.NOT_FOUND)); + return; + } + if (historical) { + // stop historical analysis + getAndExecuteOnLatestDetectorLevelTask( + detectorId, + HISTORICAL_DETECTOR_TASK_TYPES, + (task) -> stopHistoricalAnalysis(detectorId, task, user, listener), + transportService, + false,// don't need to reset task state when stopping the detector + listener + ); + } else { + // stop realtime detector job + handler.stopAnomalyDetectorJob(detectorId, listener); + } + }, listener); + } + + /** + * Get anomaly detector and execute consumer function. + * [Important!] Make sure listener returns in function + * + * @param detectorId detector id + * @param function consumer function + * @param listener action listener + * @param action listener response type + */ + public void getDetector(String detectorId, Consumer> function, ActionListener listener) { + GetRequest getRequest = new GetRequest(CommonName.CONFIG_INDEX, detectorId); + client.get(getRequest, ActionListener.wrap(response -> { + if (!response.isExists()) { + function.accept(Optional.empty()); + return; + } + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId(), response.getVersion()); + + function.accept(Optional.of(detector)); + } catch (Exception e) { + String message = "Failed to parse anomaly detector " + detectorId; + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + }, exception -> { + logger.error("Failed to get detector " + detectorId, exception); + listener.onFailure(exception); + })); + } + + /** + * Get latest AD task and execute consumer function. + * [Important!] Make sure listener returns in function + * + * @param detectorId detector id + * @param adTaskTypes AD task types + * @param function consumer function + * @param transportService transport service + * @param resetTaskState reset task state or not + * @param listener action listener + * @param action listener response type + */ + public void getAndExecuteOnLatestDetectorLevelTask( + String detectorId, + List adTaskTypes, + Consumer> function, + TransportService transportService, + boolean resetTaskState, + ActionListener listener + ) { + getAndExecuteOnLatestADTask(detectorId, null, null, adTaskTypes, function, transportService, resetTaskState, listener); + } + + /** + * Get one latest AD task and execute consumer function. + * [Important!]
Make sure listener returns in function + * + * @param detectorId detector id + * @param parentTaskId parent task id + * @param entity entity value + * @param adTaskTypes AD task types + * @param function consumer function + * @param transportService transport service + * @param resetTaskState reset task state or not + * @param listener action listener + * @param action listener response type + */ + public void getAndExecuteOnLatestADTask( + String detectorId, + String parentTaskId, + Entity entity, + List adTaskTypes, + Consumer> function, + TransportService transportService, + boolean resetTaskState, + ActionListener listener + ) { + getAndExecuteOnLatestADTasks(detectorId, parentTaskId, entity, adTaskTypes, (taskList) -> { + if (taskList != null && taskList.size() > 0) { + function.accept(Optional.ofNullable(taskList.get(0))); + } else { + function.accept(Optional.empty()); + } + }, transportService, resetTaskState, 1, listener); + } + + /** + * Get latest AD tasks and execute consumer function. + * If resetTaskState is true, will collect latest task's profile data from all data nodes. If no data + * node running the latest task, will reset the task state as STOPPED; otherwise, check if there is + * any stale running entities(entity exists in coordinating node cache but no task running on worker + * node) and clean up. + * [Important!] Make sure listener returns in function + * + * @param detectorId detector id + * @param parentTaskId parent task id + * @param entity entity value + * @param adTaskTypes AD task types + * @param function consumer function + * @param transportService transport service + * @param resetTaskState reset task state or not + * @param size return how many AD tasks + * @param listener action listener + * @param response type of action listener + */ + public void getAndExecuteOnLatestADTasks( + String detectorId, + String parentTaskId, + Entity entity, + List adTaskTypes, + Consumer> function, + TransportService transportService, + boolean resetTaskState, + int size, + ActionListener listener + ) { + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, detectorId)); + query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true)); + if (parentTaskId != null) { + query.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, parentTaskId)); + } + if (adTaskTypes != null && adTaskTypes.size() > 0) { + query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(adTaskTypes))); + } + if (entity != null && !isNullOrEmpty(entity.getAttributes())) { + String path = "entity"; + String entityKeyFieldName = path + ".name"; + String entityValueFieldName = path + ".value"; + + for (Map.Entry attribute : entity.getAttributes().entrySet()) { + BoolQueryBuilder entityBoolQuery = new BoolQueryBuilder(); + TermQueryBuilder entityKeyFilterQuery = QueryBuilders.termQuery(entityKeyFieldName, attribute.getKey()); + TermQueryBuilder entityValueFilterQuery = QueryBuilders.termQuery(entityValueFieldName, attribute.getValue()); + + entityBoolQuery.filter(entityKeyFilterQuery).filter(entityValueFilterQuery); + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(path, entityBoolQuery, ScoreMode.None); + query.filter(nestedQueryBuilder); + } + } + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(query).sort(EXECUTION_START_TIME_FIELD, SortOrder.DESC).size(size); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(sourceBuilder); + searchRequest.indices(DETECTION_STATE_INDEX); 
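+ // The search above targets the detection state index and returns only the latest (is_latest=true) tasks of this detector, sorted by execution start time descending, so the first hit is the most recent task.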
+ + client.search(searchRequest, ActionListener.wrap(r -> { + // https://github.com/opendistro-for-elasticsearch/anomaly-detection/pull/359#discussion_r558653132 + // getTotalHits will be null when we track_total_hits is false in the query request. + // Add more checking here to cover some unknown cases. + List adTasks = new ArrayList<>(); + if (r == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { + // don't throw exception here as consumer functions need to handle missing task + // in different way. + function.accept(adTasks); + return; + } + + Iterator iterator = r.getHits().iterator(); + while (iterator.hasNext()) { + SearchHit searchHit = iterator.next(); + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + ADTask adTask = ADTask.parse(parser, searchHit.getId()); + adTasks.add(adTask); + } catch (Exception e) { + String message = "Failed to parse AD task for detector " + detectorId + ", task id " + searchHit.getId(); + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + } + if (resetTaskState) { + resetLatestDetectorTaskState(adTasks, function, transportService, listener); + } else { + function.accept(adTasks); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + function.accept(new ArrayList<>()); + } else { + logger.error("Failed to search AD task for detector " + detectorId, e); + listener.onFailure(e); + } + })); + } + + /** + * Reset latest detector task state. Will reset both historical and realtime tasks. + * [Important!] Make sure listener returns in function + * + * @param adTasks ad tasks + * @param function consumer function + * @param transportService transport service + * @param listener action listener + * @param response type of action listener + */ + private void resetLatestDetectorTaskState( + List adTasks, + Consumer> function, + TransportService transportService, + ActionListener listener + ) { + List runningHistoricalTasks = new ArrayList<>(); + List runningRealtimeTasks = new ArrayList<>(); + for (ADTask adTask : adTasks) { + if (!adTask.isEntityTask() && !adTask.isDone()) { + if (!adTask.isHistoricalTask()) { + // try to reset task state if realtime task is not ended + runningRealtimeTasks.add(adTask); + } else { + // try to reset task state if historical task not updated for 2 piece intervals + runningHistoricalTasks.add(adTask); + } + } + } + + resetHistoricalDetectorTaskState( + runningHistoricalTasks, + () -> resetRealtimeDetectorTaskState(runningRealtimeTasks, () -> function.accept(adTasks), transportService, listener), + transportService, + listener + ); + } + + private void resetRealtimeDetectorTaskState( + List runningRealtimeTasks, + ExecutorFunction function, + TransportService transportService, + ActionListener listener + ) { + if (isNullOrEmpty(runningRealtimeTasks)) { + function.execute(); + return; + } + ADTask adTask = runningRealtimeTasks.get(0); + String detectorId = adTask.getId(); + GetRequest getJobRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId); + client.get(getJobRequest, ActionListener.wrap(r -> { + if (r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetectorJob job = 
AnomalyDetectorJob.parse(parser); + if (!job.isEnabled()) { + logger.debug("AD job is disabled, reset realtime task as stopped for detector {}", detectorId); + resetTaskStateAsStopped(adTask, function, transportService, listener); + } else { + function.execute(); + } + } catch (IOException e) { + logger.error(" Failed to parse AD job " + detectorId, e); + listener.onFailure(e); + } + } else { + logger.debug("AD job is not found, reset realtime task as stopped for detector {}", detectorId); + resetTaskStateAsStopped(adTask, function, transportService, listener); + } + }, e -> { + logger.error("Fail to get AD realtime job for detector " + detectorId, e); + listener.onFailure(e); + })); + } + + private void resetHistoricalDetectorTaskState( + List runningHistoricalTasks, + ExecutorFunction function, + TransportService transportService, + ActionListener listener + ) { + if (isNullOrEmpty(runningHistoricalTasks)) { + function.execute(); + return; + } + ADTask adTask = runningHistoricalTasks.get(0); + // If AD task is still running, but its last updated time not refreshed for 2 piece intervals, we will get + // task profile to check if it's really running. If task not running, reset state as STOPPED. + // For example, ES process crashes, then all tasks running on it will stay as running. We can reset the task + // state when get historical task with get detector API. + if (!lastUpdateTimeOfHistoricalTaskExpired(adTask)) { + function.execute(); + return; + } + String taskId = adTask.getTaskId(); + AnomalyDetector detector = adTask.getDetector(); + getADTaskProfile(adTask, ActionListener.wrap(taskProfile -> { + boolean taskStopped = isTaskStopped(taskId, detector, taskProfile); + if (taskStopped) { + logger.debug("Reset task state as stopped, task id: {}", adTask.getTaskId()); + if (taskProfile.getTaskId() == null // This means coordinating node doesn't have HC detector cache + && detector.isHighCardinality() + && !isNullOrEmpty(taskProfile.getEntityTaskProfiles())) { + // If coordinating node restarted, HC detector cache on it will be gone. But worker node still + // runs entity tasks, we'd better stop these entity tasks to clean up resource earlier. + stopHistoricalAnalysis(adTask.getId(), Optional.of(adTask), null, ActionListener.wrap(r -> { + logger.debug("Restop detector successfully"); + resetTaskStateAsStopped(adTask, function, transportService, listener); + }, e -> { + logger.error("Failed to restop detector ", e); + listener.onFailure(e); + })); + } else { + resetTaskStateAsStopped(adTask, function, transportService, listener); + } + } else { + function.execute(); + // If still running, check if there is any stale running entities and clean them + if (ADTaskType.HISTORICAL_HC_DETECTOR.name().equals(adTask.getTaskType())) { + // Check if any running entity not run on worker node. If yes, we need to remove it + // and poll next entity from pending entity queue and run it. 
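+ // An entity is stale if the coordinating node caches it as running but no worker node reports an entity task for it; such entities are forwarded below for cleanup.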
+ if (!isNullOrEmpty(taskProfile.getRunningEntities()) && hcBatchTaskExpired(taskProfile.getLatestHCTaskRunTime())) { + List runningTasksInCoordinatingNodeCache = new ArrayList<>(taskProfile.getRunningEntities()); + List runningTasksOnWorkerNode = new ArrayList<>(); + if (taskProfile.getEntityTaskProfiles() != null && taskProfile.getEntityTaskProfiles().size() > 0) { + taskProfile + .getEntityTaskProfiles() + .forEach(entryTask -> runningTasksOnWorkerNode.add(convertEntityToString(entryTask.getEntity(), detector))); + } + + if (runningTasksInCoordinatingNodeCache.size() > runningTasksOnWorkerNode.size()) { + runningTasksInCoordinatingNodeCache.removeAll(runningTasksOnWorkerNode); + forwardStaleRunningEntitiesToCoordinatingNode( + adTask, + ADTaskAction.CLEAN_STALE_RUNNING_ENTITIES, + transportService, + runningTasksInCoordinatingNodeCache, + ActionListener + .wrap( + res -> logger.debug("Forwarded task to clean stale running entity, task id {}", taskId), + ex -> logger.error("Failed to forward clean stale running entity for task " + taskId, ex) + ) + ); + } + } + } + } + }, e -> { + logger.error("Failed to get AD task profile for task " + adTask.getTaskId(), e); + function.execute(); + })); + } + + private boolean isTaskStopped(String taskId, AnomalyDetector detector, ADTaskProfile taskProfile) { + String detectorId = detector.getId(); + if (taskProfile == null || !Objects.equals(taskId, taskProfile.getTaskId())) { + logger.debug("AD task not found for task {} detector {}", taskId, detectorId); + // If no node is running this task, reset it as STOPPED. + return true; + } + if (!detector.isHighCardinality() && taskProfile.getNodeId() == null) { + logger.debug("AD task not running for single entity detector {}, task {}", detectorId, taskId); + return true; + } + if (detector.isHighCardinality() + && taskProfile.getTotalEntitiesInited() + && isNullOrEmpty(taskProfile.getRunningEntities()) + && isNullOrEmpty(taskProfile.getEntityTaskProfiles()) + && hcBatchTaskExpired(taskProfile.getLatestHCTaskRunTime())) { + logger.debug("AD task not running for HC detector {}, task {}", detectorId, taskId); + return true; + } + return false; + } + + public boolean hcBatchTaskExpired(Long latestHCTaskRunTime) { + if (latestHCTaskRunTime == null) { + return true; + } + return latestHCTaskRunTime + HC_BATCH_TASK_CACHE_TIMEOUT_IN_MILLIS < Instant.now().toEpochMilli(); + } + + private void stopHistoricalAnalysis( + String detectorId, + Optional adTask, + User user, + ActionListener listener + ) { + if (!adTask.isPresent()) { + listener.onFailure(new ResourceNotFoundException(detectorId, "Detector not started")); + return; + } + + if (adTask.get().isDone()) { + listener.onFailure(new ResourceNotFoundException(detectorId, "No running task found")); + return; + } + + String taskId = adTask.get().getTaskId(); + DiscoveryNode[] dataNodes = hashRing.getNodesWithSameLocalAdVersion(); + String userName = user == null ? null : user.getName(); + + ADCancelTaskRequest cancelTaskRequest = new ADCancelTaskRequest(detectorId, taskId, userName, dataNodes); + client + .execute( + ADCancelTaskAction.INSTANCE, + cancelTaskRequest, + ActionListener + .wrap(response -> { listener.onResponse(new AnomalyDetectorJobResponse(taskId, 0, 0, 0, RestStatus.OK)); }, e -> { + logger.error("Failed to cancel AD task " + taskId + ", detector id: " + detectorId, e); + listener.onFailure(e); + }) + ); + } + + private boolean lastUpdateTimeOfHistoricalTaskExpired(ADTask adTask) { + // Wait at least 10 seconds. 
Piece interval seconds is dynamic setting, user could change it to a smaller value. + int waitingTime = Math.max(2 * pieceIntervalSeconds, 10); + return adTask.getLastUpdateTime().plus(waitingTime, ChronoUnit.SECONDS).isBefore(Instant.now()); + } + + private void resetTaskStateAsStopped( + ADTask adTask, + ExecutorFunction function, + TransportService transportService, + ActionListener listener + ) { + cleanDetectorCache(adTask, transportService, () -> { + String taskId = adTask.getTaskId(); + Map updatedFields = ImmutableMap.of(STATE_FIELD, ADTaskState.STOPPED.name()); + updateADTask(taskId, updatedFields, ActionListener.wrap(r -> { + adTask.setState(ADTaskState.STOPPED.name()); + if (function != null) { + function.execute(); + } + // For realtime anomaly detection, we only create detector level task, no entity level realtime task. + if (ADTaskType.HISTORICAL_HC_DETECTOR.name().equals(adTask.getTaskType())) { + // Reset running entity tasks as STOPPED + resetEntityTasksAsStopped(taskId); + } + }, e -> { + logger.error("Failed to update task state as STOPPED for task " + taskId, e); + listener.onFailure(e); + })); + }, listener); + } + + private void resetEntityTasksAsStopped(String detectorTaskId) { + UpdateByQueryRequest updateByQueryRequest = new UpdateByQueryRequest(); + updateByQueryRequest.indices(DETECTION_STATE_INDEX); + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, detectorTaskId)); + query.filter(new TermQueryBuilder(TASK_TYPE_FIELD, ADTaskType.HISTORICAL_HC_ENTITY.name())); + query.filter(new TermsQueryBuilder(STATE_FIELD, NOT_ENDED_STATES)); + updateByQueryRequest.setQuery(query); + updateByQueryRequest.setRefresh(true); + String script = String.format(Locale.ROOT, "ctx._source.%s='%s';", STATE_FIELD, ADTaskState.STOPPED.name()); + updateByQueryRequest.setScript(new Script(script)); + + client.execute(UpdateByQueryAction.INSTANCE, updateByQueryRequest, ActionListener.wrap(r -> { + List bulkFailures = r.getBulkFailures(); + if (isNullOrEmpty(bulkFailures)) { + logger.debug("Updated {} child entity tasks state for detector task {}", r.getUpdated(), detectorTaskId); + } else { + logger.error("Failed to update child entity task's state for detector task {} ", detectorTaskId); + } + }, e -> logger.error("Exception happened when update child entity task's state for detector task " + detectorTaskId, e))); + } + + /** + * Clean detector cache on coordinating node. + * If task's coordinating node is still in cluster, will forward stop + * task request to coordinating node, then coordinating node will + * remove detector from cache. + * If task's coordinating node is not in cluster, we don't need to + * forward stop task request to coordinating node. + * [Important!] 
Make sure listener returns in function + * + * @param adTask AD task + * @param transportService transport service + * @param function will execute it when detector cache cleaned successfully or coordinating node left cluster + * @param listener action listener + * @param response type of listener + */ + public void cleanDetectorCache( + ADTask adTask, + TransportService transportService, + ExecutorFunction function, + ActionListener listener + ) { + String coordinatingNode = adTask.getCoordinatingNode(); + String detectorId = adTask.getId(); + String taskId = adTask.getTaskId(); + try { + forwardADTaskToCoordinatingNode( + adTask, + ADTaskAction.CLEAN_CACHE, + transportService, + ActionListener.wrap(r -> { function.execute(); }, e -> { + logger.error("Failed to clear detector cache on coordinating node " + coordinatingNode, e); + listener.onFailure(e); + }) + ); + } catch (ResourceNotFoundException e) { + logger + .warn( + "Task coordinating node left cluster, taskId: {}, detectorId: {}, coordinatingNode: {}", + taskId, + detectorId, + coordinatingNode + ); + function.execute(); + } catch (Exception e) { + logger.error("Failed to forward clean cache event for detector " + detectorId + ", task " + taskId, e); + listener.onFailure(e); + } + } + + protected void cleanDetectorCache(ADTask adTask, TransportService transportService, ExecutorFunction function) { + String detectorId = adTask.getId(); + String taskId = adTask.getTaskId(); + cleanDetectorCache( + adTask, + transportService, + function, + ActionListener + .wrap( + r -> { logger.debug("Successfully cleaned cache for detector {}, task {}", detectorId, taskId); }, + e -> { logger.error("Failed to clean cache for detector " + detectorId + ", task " + taskId, e); } + ) + ); + } + + /** + * Get latest historical AD task profile. + * Will not reset task state in this method. + * + * @param detectorId detector id + * @param transportService transport service + * @param profile detector profile + * @param listener action listener + */ + public void getLatestHistoricalTaskProfile( + String detectorId, + TransportService transportService, + DetectorProfile profile, + ActionListener listener + ) { + getAndExecuteOnLatestADTask(detectorId, null, null, HISTORICAL_DETECTOR_TASK_TYPES, adTask -> { + if (adTask.isPresent()) { + getADTaskProfile(adTask.get(), ActionListener.wrap(adTaskProfile -> { + DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); + profileBuilder.adTaskProfile(adTaskProfile); + DetectorProfile detectorProfile = profileBuilder.build(); + detectorProfile.merge(profile); + listener.onResponse(detectorProfile); + }, e -> { + logger.error("Failed to get AD task profile for task " + adTask.get().getTaskId(), e); + listener.onFailure(e); + })); + } else { + DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder(); + listener.onResponse(profileBuilder.build()); + } + }, transportService, false, listener); + } + + /** + * Get AD task profile. 
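+ * Collects node-level task profiles from all eligible data nodes and merges them into a single detector-level profile.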
+ * @param adDetectorLevelTask detector level task + * @param listener action listener + */ + private void getADTaskProfile(ADTask adDetectorLevelTask, ActionListener listener) { + String detectorId = adDetectorLevelTask.getId(); + + hashRing.getAllEligibleDataNodesWithKnownAdVersion(dataNodes -> { + ADTaskProfileRequest adTaskProfileRequest = new ADTaskProfileRequest(detectorId, dataNodes); + client.execute(ADTaskProfileAction.INSTANCE, adTaskProfileRequest, ActionListener.wrap(response -> { + if (response.hasFailures()) { + listener.onFailure(response.failures().get(0)); + return; + } + + List adEntityTaskProfiles = new ArrayList<>(); + ADTaskProfile detectorTaskProfile = new ADTaskProfile(adDetectorLevelTask); + for (ADTaskProfileNodeResponse node : response.getNodes()) { + ADTaskProfile taskProfile = node.getAdTaskProfile(); + if (taskProfile != null) { + if (taskProfile.getNodeId() != null) { + // HC detector: task profile from coordinating node + // Single entity detector: task profile from worker node + detectorTaskProfile.setTaskId(taskProfile.getTaskId()); + detectorTaskProfile.setShingleSize(taskProfile.getShingleSize()); + detectorTaskProfile.setRcfTotalUpdates(taskProfile.getRcfTotalUpdates()); + detectorTaskProfile.setThresholdModelTrained(taskProfile.getThresholdModelTrained()); + detectorTaskProfile.setThresholdModelTrainingDataSize(taskProfile.getThresholdModelTrainingDataSize()); + detectorTaskProfile.setModelSizeInBytes(taskProfile.getModelSizeInBytes()); + detectorTaskProfile.setNodeId(taskProfile.getNodeId()); + detectorTaskProfile.setTotalEntitiesCount(taskProfile.getTotalEntitiesCount()); + detectorTaskProfile.setDetectorTaskSlots(taskProfile.getDetectorTaskSlots()); + detectorTaskProfile.setPendingEntitiesCount(taskProfile.getPendingEntitiesCount()); + detectorTaskProfile.setRunningEntitiesCount(taskProfile.getRunningEntitiesCount()); + detectorTaskProfile.setRunningEntities(taskProfile.getRunningEntities()); + detectorTaskProfile.setAdTaskType(taskProfile.getAdTaskType()); + } + if (taskProfile.getEntityTaskProfiles() != null) { + adEntityTaskProfiles.addAll(taskProfile.getEntityTaskProfiles()); + } + } + } + if (adEntityTaskProfiles != null && adEntityTaskProfiles.size() > 0) { + detectorTaskProfile.setEntityTaskProfiles(adEntityTaskProfiles); + } + listener.onResponse(detectorTaskProfile); + }, e -> { + logger.error("Failed to get task profile for task " + adDetectorLevelTask.getTaskId(), e); + listener.onFailure(e); + })); + }, listener); + + } + + private String validateDetector(AnomalyDetector detector) { + String error = null; + if (detector.getFeatureAttributes().size() == 0) { + error = "Can't start detector job as no features configured"; + } else if (detector.getEnabledFeatureIds().size() == 0) { + error = "Can't start detector job as no enabled features configured"; + } + return error; + } + + private void updateLatestFlagOfOldTasksAndCreateNewTask( + AnomalyDetector detector, + DateRange detectionDateRange, + User user, + ActionListener listener + ) { + UpdateByQueryRequest updateByQueryRequest = new UpdateByQueryRequest(); + updateByQueryRequest.indices(DETECTION_STATE_INDEX); + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, detector.getId())); + query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true)); + // make sure we reset all latest task as false when user switch from single entity to HC, vice versa. 
+ query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(getADTaskTypes(detectionDateRange, true)))); + updateByQueryRequest.setQuery(query); + updateByQueryRequest.setRefresh(true); + String script = String.format(Locale.ROOT, "ctx._source.%s=%s;", IS_LATEST_FIELD, false); + updateByQueryRequest.setScript(new Script(script)); + + client.execute(UpdateByQueryAction.INSTANCE, updateByQueryRequest, ActionListener.wrap(r -> { + List bulkFailures = r.getBulkFailures(); + if (bulkFailures.isEmpty()) { + // Realtime AD coordinating node is chosen by job scheduler, we won't know it until realtime AD job + // runs. Just set realtime AD coordinating node as null here, and AD job runner will reset correct + // coordinating node once realtime job starts. + // For historical analysis, this method will be called on coordinating node, so we can set coordinating + // node as local node. + String coordinatingNode = detectionDateRange == null ? null : clusterService.localNode().getId(); + createNewADTask(detector, detectionDateRange, user, coordinatingNode, listener); + } else { + logger.error("Failed to update old task's state for detector: {}, response: {} ", detector.getId(), r.toString()); + listener.onFailure(bulkFailures.get(0).getCause()); + } + }, e -> { + logger.error("Failed to reset old tasks as not latest for detector " + detector.getId(), e); + listener.onFailure(e); + })); + } + + private void createNewADTask( + AnomalyDetector detector, + DateRange detectionDateRange, + User user, + String coordinatingNode, + ActionListener listener + ) { + String userName = user == null ? null : user.getName(); + Instant now = Instant.now(); + String taskType = getADTaskType(detector, detectionDateRange).name(); + ADTask adTask = new ADTask.Builder() + .detectorId(detector.getId()) + .detector(detector) + .isLatest(true) + .taskType(taskType) + .executionStartTime(now) + .taskProgress(0.0f) + .initProgress(0.0f) + .state(ADTaskState.CREATED.name()) + .lastUpdateTime(now) + .startedBy(userName) + .coordinatingNode(coordinatingNode) + .detectionDateRange(detectionDateRange) + .user(user) + .build(); + + createADTaskDirectly( + adTask, + r -> onIndexADTaskResponse( + r, + adTask, + (response, delegatedListener) -> cleanOldAdTaskDocs(response, adTask, delegatedListener), + listener + ), + listener + ); + } + + /** + * Create AD task directly without checking whether the index exists or not. + * [Important!]
Make sure listener returns in function + * + * @param adTask AD task + * @param function consumer function + * @param listener action listener + * @param action listener response type + */ + public void createADTaskDirectly(ADTask adTask, Consumer function, ActionListener listener) { + IndexRequest request = new IndexRequest(DETECTION_STATE_INDEX); + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + request + .source(adTask.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(request, ActionListener.wrap(r -> function.accept(r), e -> { + logger.error("Failed to create AD task for detector " + adTask.getId(), e); + listener.onFailure(e); + })); + } catch (Exception e) { + logger.error("Failed to create AD task for detector " + adTask.getId(), e); + listener.onFailure(e); + } + } + + private void onIndexADTaskResponse( + IndexResponse response, + ADTask adTask, + BiConsumer> function, + ActionListener listener + ) { + if (response == null || response.getResult() != CREATED) { + String errorMsg = getShardsFailure(response); + listener.onFailure(new OpenSearchStatusException(errorMsg, response.status())); + return; + } + adTask.setTaskId(response.getId()); + ActionListener delegatedListener = ActionListener.wrap(r -> { listener.onResponse(r); }, e -> { + handleADTaskException(adTask, e); + if (e instanceof DuplicateTaskException) { + listener.onFailure(new OpenSearchStatusException(DETECTOR_IS_RUNNING, RestStatus.BAD_REQUEST)); + } else { + // For historical AD task, clear historical task if any other exception happened. + // For realtime AD, task cache will be inited when realtime job starts, check + // ADTaskManager#initRealtimeTaskCacheAndCleanupStaleCache for details. Here the + // realtime task cache not inited yet when create AD task, so no need to cleanup. + if (adTask.isHistoricalTask()) { + adTaskCacheManager.removeHistoricalTaskCache(adTask.getId()); + } + listener.onFailure(e); + } + }); + try { + // Put detector id in cache. If detector id already in cache, will throw + // DuplicateTaskException. This is to solve race condition when user send + // multiple start request for one historical detector. + if (adTask.isHistoricalTask()) { + adTaskCacheManager.add(adTask.getId(), adTask); + } + } catch (Exception e) { + delegatedListener.onFailure(e); + return; + } + if (function != null) { + function.accept(response, delegatedListener); + } + } + + private void cleanOldAdTaskDocs(IndexResponse response, ADTask adTask, ActionListener delegatedListener) { + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, adTask.getId())); + query.filter(new TermQueryBuilder(IS_LATEST_FIELD, false)); + + if (adTask.isHistoricalTask()) { + // If historical task, only delete detector level task. It may take longer time to delete entity tasks. + // We will delete child task (entity task) of detector level task in hourly cron job. + query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(HISTORICAL_DETECTOR_TASK_TYPES))); + } else { + // We don't have entity level task for realtime detection, so will delete all tasks. 
+ query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(REALTIME_TASK_TYPES))); + } + + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder + .query(query) + .sort(EXECUTION_START_TIME_FIELD, SortOrder.DESC) + // Search query "from" starts from 0. + .from(maxOldAdTaskDocsPerDetector) + .size(MAX_OLD_AD_TASK_DOCS); + searchRequest.source(sourceBuilder).indices(DETECTION_STATE_INDEX); + String detectorId = adTask.getId(); + + deleteTaskDocs(detectorId, searchRequest, () -> { + if (adTask.isHistoricalTask()) { + // run batch result action for historical detection + runBatchResultAction(response, adTask, delegatedListener); + } else { + // return response directly for realtime detection + AnomalyDetectorJobResponse anomalyDetectorJobResponse = new AnomalyDetectorJobResponse( + response.getId(), + response.getVersion(), + response.getSeqNo(), + response.getPrimaryTerm(), + RestStatus.OK + ); + delegatedListener.onResponse(anomalyDetectorJobResponse); + } + }, delegatedListener); + } + + protected void deleteTaskDocs( + String detectorId, + SearchRequest searchRequest, + ExecutorFunction function, + ActionListener listener + ) { + ActionListener searchListener = ActionListener.wrap(r -> { + Iterator iterator = r.getHits().iterator(); + if (iterator.hasNext()) { + BulkRequest bulkRequest = new BulkRequest(); + while (iterator.hasNext()) { + SearchHit searchHit = iterator.next(); + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + ADTask adTask = ADTask.parse(parser, searchHit.getId()); + logger.debug("Delete old task: {} of detector: {}", adTask.getTaskId(), adTask.getId()); + bulkRequest.add(new DeleteRequest(DETECTION_STATE_INDEX).id(adTask.getTaskId())); + } catch (Exception e) { + listener.onFailure(e); + } + } + client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(res -> { + logger.info("Old AD tasks deleted for detector {}", detectorId); + BulkItemResponse[] bulkItemResponses = res.getItems(); + if (bulkItemResponses != null && bulkItemResponses.length > 0) { + for (BulkItemResponse bulkItemResponse : bulkItemResponses) { + if (!bulkItemResponse.isFailed()) { + logger.debug("Add detector task into cache. Task id: {}", bulkItemResponse.getId()); + // add deleted task in cache and delete its child tasks and AD results + adTaskCacheManager.addDeletedDetectorTask(bulkItemResponse.getId()); + } + } + } + // delete child tasks and AD results of this task + cleanChildTasksAndADResultsOfDeletedTask(); + + function.execute(); + }, e -> { + logger.warn("Failed to clean AD tasks for detector " + detectorId, e); + listener.onFailure(e); + })); + } else { + function.execute(); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + function.execute(); + } else { + listener.onFailure(e); + } + }); + + client.search(searchRequest, searchListener); + } + + /** + * Poll deleted detector task from cache and delete its child tasks and AD results. 
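+ * Reschedules itself on the AD batch task thread pool until the deleted-task queue in the cache manager is drained.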
+ */ + public void cleanChildTasksAndADResultsOfDeletedTask() { + if (!adTaskCacheManager.hasDeletedDetectorTask()) { + return; + } + threadPool.schedule(() -> { + String taskId = adTaskCacheManager.pollDeletedDetectorTask(); + if (taskId == null) { + return; + } + DeleteByQueryRequest deleteADResultsRequest = new DeleteByQueryRequest(ALL_AD_RESULTS_INDEX_PATTERN); + deleteADResultsRequest.setQuery(new TermsQueryBuilder(TASK_ID_FIELD, taskId)); + client.execute(DeleteByQueryAction.INSTANCE, deleteADResultsRequest, ActionListener.wrap(res -> { + logger.debug("Successfully deleted AD results of task " + taskId); + DeleteByQueryRequest deleteChildTasksRequest = new DeleteByQueryRequest(DETECTION_STATE_INDEX); + deleteChildTasksRequest.setQuery(new TermsQueryBuilder(PARENT_TASK_ID_FIELD, taskId)); + + client.execute(DeleteByQueryAction.INSTANCE, deleteChildTasksRequest, ActionListener.wrap(r -> { + logger.debug("Successfully deleted child tasks of task " + taskId); + cleanChildTasksAndADResultsOfDeletedTask(); + }, e -> { logger.error("Failed to delete child tasks of task " + taskId, e); })); + }, ex -> { logger.error("Failed to delete AD results for task " + taskId, ex); })); + }, TimeValue.timeValueSeconds(DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS), AD_BATCH_TASK_THREAD_POOL_NAME); + } + + private void runBatchResultAction(IndexResponse response, ADTask adTask, ActionListener listener) { + client.execute(ADBatchAnomalyResultAction.INSTANCE, new ADBatchAnomalyResultRequest(adTask), ActionListener.wrap(r -> { + String remoteOrLocal = r.isRunTaskRemotely() ? "remote" : "local"; + logger + .info( + "AD task {} of detector {} dispatched to {} node {}", + adTask.getTaskId(), + adTask.getId(), + remoteOrLocal, + r.getNodeId() + ); + AnomalyDetectorJobResponse anomalyDetectorJobResponse = new AnomalyDetectorJobResponse( + response.getId(), + response.getVersion(), + response.getSeqNo(), + response.getPrimaryTerm(), + RestStatus.OK + ); + listener.onResponse(anomalyDetectorJobResponse); + }, e -> listener.onFailure(e))); + } + + /** + * Handle exceptions for AD task. Update task state and record error message. + * + * @param adTask AD task + * @param e exception + */ + public void handleADTaskException(ADTask adTask, Exception e) { + // TODO: handle timeout exception + String state = ADTaskState.FAILED.name(); + Map updatedFields = new HashMap<>(); + if (e instanceof DuplicateTaskException) { + // If user send multiple start detector request, we will meet race condition. + // Cache manager will put first request in cache and throw DuplicateTaskException + // for the second request. We will delete the second task. + logger + .warn( + "There is already one running task for detector, detectorId:" + + adTask.getId() + + ". 
Will delete task " + + adTask.getTaskId() + ); + deleteADTask(adTask.getTaskId()); + return; + } + if (e instanceof TaskCancelledException) { + logger.info("AD task cancelled, taskId: {}, detectorId: {}", adTask.getTaskId(), adTask.getId()); + state = ADTaskState.STOPPED.name(); + String stoppedBy = ((TaskCancelledException) e).getCancelledBy(); + if (stoppedBy != null) { + updatedFields.put(STOPPED_BY_FIELD, stoppedBy); + } + } else { + logger.error("Failed to execute AD batch task, task id: " + adTask.getTaskId() + ", detector id: " + adTask.getId(), e); + } + updatedFields.put(ERROR_FIELD, getErrorMessage(e)); + updatedFields.put(STATE_FIELD, state); + updatedFields.put(EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli()); + updateADTask(adTask.getTaskId(), updatedFields); + } + + /** + * Update AD task with specific fields. + * + * @param taskId AD task id + * @param updatedFields updated fields, key: field name, value: new value + */ + public void updateADTask(String taskId, Map updatedFields) { + updateADTask(taskId, updatedFields, ActionListener.wrap(response -> { + if (response.status() == RestStatus.OK) { + logger.debug("Updated AD task successfully: {}, task id: {}", response.status(), taskId); + } else { + logger.error("Failed to update AD task {}, status: {}", taskId, response.status()); + } + }, e -> { logger.error("Failed to update task: " + taskId, e); })); + } + + /** + * Update AD task for specific fields. + * + * @param taskId task id + * @param updatedFields updated fields, key: field name, value: new value + * @param listener action listener + */ + public void updateADTask(String taskId, Map updatedFields, ActionListener listener) { + UpdateRequest updateRequest = new UpdateRequest(DETECTION_STATE_INDEX, taskId); + Map updatedContent = new HashMap<>(); + updatedContent.putAll(updatedFields); + updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli()); + updateRequest.doc(updatedContent); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.update(updateRequest, listener); + } + + /** + * Delete AD task with task id. + * + * @param taskId AD task id + */ + public void deleteADTask(String taskId) { + deleteADTask( + taskId, + ActionListener + .wrap( + r -> { logger.info("Deleted AD task {} with status: {}", taskId, r.status()); }, + e -> { logger.error("Failed to delete AD task " + taskId, e); } + ) + ); + } + + /** + * Delete AD task with task id. + * + * @param taskId AD task id + * @param listener action listener + */ + public void deleteADTask(String taskId, ActionListener listener) { + DeleteRequest deleteRequest = new DeleteRequest(DETECTION_STATE_INDEX, taskId); + client.delete(deleteRequest, listener); + } + + /** + * Cancel running task by detector id. + * + * @param detectorId detector id + * @param detectorTaskId detector level task id + * @param reason reason to cancel AD task + * @param userName which user cancelled the AD task + * @return AD task cancellation state + */ + public ADTaskCancellationState cancelLocalTaskByDetectorId(String detectorId, String detectorTaskId, String reason, String userName) { + ADTaskCancellationState cancellationState = adTaskCacheManager.cancelByDetectorId(detectorId, detectorTaskId, reason, userName); + logger + .debug( + "Cancelled AD task for detector: {}, state: {}, cancelled by: {}, reason: {}", + detectorId, + cancellationState, + userName, + reason + ); + return cancellationState; + } + + /** + * Delete AD task docs. + * [Important!]
Make sure listener returns in function + * + * @param detectorId detector id + * @param function AD function + * @param listener action listener + */ + public void deleteADTasks(String detectorId, ExecutorFunction function, ActionListener listener) { + DeleteByQueryRequest request = new DeleteByQueryRequest(DETECTION_STATE_INDEX); + + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, detectorId)); + + request.setQuery(query); + client.execute(DeleteByQueryAction.INSTANCE, request, ActionListener.wrap(r -> { + if (r.getBulkFailures() == null || r.getBulkFailures().size() == 0) { + logger.info("AD tasks deleted for detector {}", detectorId); + deleteADResultOfDetector(detectorId); + function.execute(); + } else { + listener.onFailure(new OpenSearchStatusException("Failed to delete all AD tasks", RestStatus.INTERNAL_SERVER_ERROR)); + } + }, e -> { + logger.info("Failed to delete AD tasks for " + detectorId, e); + if (e instanceof IndexNotFoundException) { + deleteADResultOfDetector(detectorId); + function.execute(); + } else { + listener.onFailure(e); + } + })); + } + + private void deleteADResultOfDetector(String detectorId) { + if (!deleteADResultWhenDeleteDetector) { + logger.info("Won't delete ad result for {} as delete AD result setting is disabled", detectorId); + return; + } + logger.info("Start to delete AD results of detector {}", detectorId); + DeleteByQueryRequest deleteADResultsRequest = new DeleteByQueryRequest(ALL_AD_RESULTS_INDEX_PATTERN); + deleteADResultsRequest.setQuery(new TermQueryBuilder(DETECTOR_ID_FIELD, detectorId)); + client + .execute( + DeleteByQueryAction.INSTANCE, + deleteADResultsRequest, + ActionListener + .wrap(response -> { logger.debug("Successfully deleted AD results of detector " + detectorId); }, exception -> { + logger.error("Failed to delete AD results of detector " + detectorId, exception); + adTaskCacheManager.addDeletedDetector(detectorId); + }) + ); + } + + /** + * Clean AD results of deleted detector. + */ + public void cleanADResultOfDeletedDetector() { + String detectorId = adTaskCacheManager.pollDeletedDetector(); + if (detectorId != null) { + deleteADResultOfDetector(detectorId); + } + } + + /** + * Update latest AD task of detector. + * + * @param detectorId detector id + * @param taskTypes task types + * @param updatedFields updated fields, key: field name, value: new value + * @param listener action listener + */ + public void updateLatestADTask( + String detectorId, + List taskTypes, + Map updatedFields, + ActionListener listener + ) { + getAndExecuteOnLatestDetectorLevelTask(detectorId, taskTypes, (adTask) -> { + if (adTask.isPresent()) { + updateADTask(adTask.get().getTaskId(), updatedFields, listener); + } else { + listener.onFailure(new ResourceNotFoundException(detectorId, CAN_NOT_FIND_LATEST_TASK)); + } + }, null, false, listener); + } + + /** + * Update latest realtime task.
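+ * Marks the latest realtime task with the given state and error, cleaning the detector cache on the task's coordinating node first when that node is known.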
+ * + * @param detectorId detector id + * @param state task state + * @param error error + * @param transportService transport service + * @param listener action listener + */ + public void stopLatestRealtimeTask( + String detectorId, + ADTaskState state, + Exception error, + TransportService transportService, + ActionListener listener + ) { + getAndExecuteOnLatestDetectorLevelTask(detectorId, REALTIME_TASK_TYPES, (adTask) -> { + if (adTask.isPresent() && !adTask.get().isDone()) { + Map updatedFields = new HashMap<>(); + updatedFields.put(ADTask.STATE_FIELD, state.name()); + if (error != null) { + updatedFields.put(ADTask.ERROR_FIELD, error.getMessage()); + } + ExecutorFunction function = () -> updateADTask(adTask.get().getTaskId(), updatedFields, ActionListener.wrap(r -> { + if (error == null) { + listener.onResponse(new AnomalyDetectorJobResponse(detectorId, 0, 0, 0, RestStatus.OK)); + } else { + listener.onFailure(error); + } + }, e -> { listener.onFailure(e); })); + + String coordinatingNode = adTask.get().getCoordinatingNode(); + if (coordinatingNode != null && transportService != null) { + cleanDetectorCache(adTask.get(), transportService, function, listener); + } else { + function.execute(); + } + } else { + listener.onFailure(new OpenSearchStatusException("Anomaly detector job is already stopped: " + detectorId, RestStatus.OK)); + } + }, null, false, listener); + } + + /** + * Update realtime task cache on realtime detector's coordinating node. + * + * @param detectorId detector id + * @param state new state + * @param rcfTotalUpdates rcf total updates + * @param detectorIntervalInMinutes detector interval in minutes + * @param error error + * @param listener action listener + */ + public void updateLatestRealtimeTaskOnCoordinatingNode( + String detectorId, + String state, + Long rcfTotalUpdates, + Long detectorIntervalInMinutes, + String error, + ActionListener listener + ) { + Float initProgress = null; + String newState = null; + // calculate init progress and task state with RCF total updates + if (detectorIntervalInMinutes != null && rcfTotalUpdates != null) { + newState = ADTaskState.INIT.name(); + if (rcfTotalUpdates < NUM_MIN_SAMPLES) { + initProgress = (float) rcfTotalUpdates / NUM_MIN_SAMPLES; + } else { + newState = ADTaskState.RUNNING.name(); + initProgress = 1.0f; + } + } + // Check if new state is not null and override state calculated with rcf total updates + if (state != null) { + newState = state; + } + + error = Optional.ofNullable(error).orElse(""); + if (!adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId, newState, initProgress, error)) { + // If task not changed, no need to update, just return + listener.onResponse(null); + return; + } + Map updatedFields = new HashMap<>(); + updatedFields.put(COORDINATING_NODE_FIELD, clusterService.localNode().getId()); + if (initProgress != null) { + updatedFields.put(INIT_PROGRESS_FIELD, initProgress); + updatedFields.put(ESTIMATED_MINUTES_LEFT_FIELD, Math.max(0, NUM_MIN_SAMPLES - rcfTotalUpdates) * detectorIntervalInMinutes); + } + if (newState != null) { + updatedFields.put(STATE_FIELD, newState); + } + if (error != null) { + updatedFields.put(ERROR_FIELD, error); + } + Float finalInitProgress = initProgress; + // Variable used in lambda expression should be final or effectively final + String finalError = error; + String finalNewState = newState; + updateLatestADTask(detectorId, ADTaskType.REALTIME_TASK_TYPES, updatedFields, ActionListener.wrap(r -> { + logger.debug("Updated latest realtime AD task successfully for 
detector {}", detectorId); + adTaskCacheManager.updateRealtimeTaskCache(detectorId, finalNewState, finalInitProgress, finalError); + listener.onResponse(r); + }, e -> { + logger.error("Failed to update realtime task for detector " + detectorId, e); + listener.onFailure(e); + })); + } + + /** + * Init realtime task cache and clean up realtime task cache on old coordinating node. Realtime AD + * depends on job scheduler to choose node (job coordinating node) to run AD job. Nodes have primary + * or replica shard of AD job index are candidate to run AD job. Job scheduler will build hash ring + * on these candidate nodes and choose one to run AD job. If AD job index shard relocated, for example + * new node added into cluster, then job scheduler will rebuild hash ring and may choose different + * node to run AD job. So we need to init realtime task cache on new AD job coordinating node and + * clean up cache on old coordinating node. + * + * If realtime task cache inited for the first time on this node, listener will return true; otherwise + * listener will return false. + * + * @param detectorId detector id + * @param detector anomaly detector + * @param transportService transport service + * @param listener listener + */ + public void initRealtimeTaskCacheAndCleanupStaleCache( + String detectorId, + AnomalyDetector detector, + TransportService transportService, + ActionListener listener + ) { + try { + if (adTaskCacheManager.getRealtimeTaskCache(detectorId) != null) { + listener.onResponse(false); + return; + } + + getAndExecuteOnLatestDetectorLevelTask(detectorId, REALTIME_TASK_TYPES, (adTaskOptional) -> { + if (!adTaskOptional.isPresent()) { + logger.debug("Can't find realtime task for detector {}, init realtime task cache directly", detectorId); + ExecutorFunction function = () -> createNewADTask( + detector, + null, + detector.getUser(), + clusterService.localNode().getId(), + ActionListener.wrap(r -> { + logger.info("Recreate realtime task successfully for detector {}", detectorId); + adTaskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); + listener.onResponse(true); + }, e -> { + logger.error("Failed to recreate realtime task for detector " + detectorId, e); + listener.onFailure(e); + }) + ); + recreateRealtimeTask(function, listener); + return; + } + + ADTask adTask = adTaskOptional.get(); + String localNodeId = clusterService.localNode().getId(); + String oldCoordinatingNode = adTask.getCoordinatingNode(); + if (oldCoordinatingNode != null && !localNodeId.equals(oldCoordinatingNode)) { + logger + .warn( + "AD realtime job coordinating node changed from {} to this node {} for detector {}", + oldCoordinatingNode, + localNodeId, + detectorId + ); + cleanDetectorCache(adTask, transportService, () -> { + logger + .info( + "Realtime task cache cleaned on old coordinating node {} for detector {}", + oldCoordinatingNode, + detectorId + ); + adTaskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); + listener.onResponse(true); + }, listener); + } else { + logger.info("Init realtime task cache for detector {}", detectorId); + adTaskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); + listener.onResponse(true); + } + }, transportService, false, listener); + } catch (Exception e) { + logger.error("Failed to init realtime task cache for " + detectorId, e); + listener.onFailure(e); + } + } + + private void recreateRealtimeTask(ExecutorFunction function, ActionListener listener) { + if 
(detectionIndices.doesStateIndexExist()) { + function.execute(); + } else { + // If detection index doesn't exist, create index and execute function. + detectionIndices.initStateIndex(ActionListener.wrap(r -> { + if (r.isAcknowledged()) { + logger.info("Created {} with mappings.", DETECTION_STATE_INDEX); + function.execute(); + } else { + String error = String.format(Locale.ROOT, CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); + logger.warn(error); + listener.onFailure(new OpenSearchStatusException(error, RestStatus.INTERNAL_SERVER_ERROR)); + } + }, e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) { + function.execute(); + } else { + logger.error("Failed to init anomaly detection state index", e); + listener.onFailure(e); + } + })); + } + } + + public void refreshRealtimeJobRunTime(String detectorId) { + adTaskCacheManager.refreshRealtimeJobRunTime(detectorId); + } + + public void removeRealtimeTaskCache(String detectorId) { + adTaskCacheManager.removeRealtimeTaskCache(detectorId); + } + + /** + * Send entity task done message to coordinating node. + * + * @param adTask AD task + * @param exception exception of entity task + * @param transportService transport service + */ + protected void entityTaskDone(ADTask adTask, Exception exception, TransportService transportService) { + entityTaskDone( + adTask, + exception, + transportService, + ActionListener + .wrap( + r -> logger.debug("AD task forwarded to coordinating node, task id {}", adTask.getTaskId()), + e -> logger + .error( + "AD task failed to forward to coordinating node " + + adTask.getCoordinatingNode() + + " for task " + + adTask.getTaskId(), + e + ) + ) + ); + } + + private void entityTaskDone( + ADTask adTask, + Exception exception, + TransportService transportService, + ActionListener listener + ) { + try { + ADTaskAction action = getAdEntityTaskAction(adTask, exception); + forwardADTaskToCoordinatingNode(adTask, action, transportService, listener); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Get AD entity task action based on exception. + * 1. If exception is null, return NEXT_ENTITY action which will poll next + * entity to run. + * 2. If exception is retryable, return PUSH_BACK_ENTITY action which will + * push entity back to pendig queue. + * 3. If exception is task cancelled exception, return CANCEL action which + * will stop HC detector run. + * + * @param adTask AD task + * @param exception exception + * @return AD task action + */ + private ADTaskAction getAdEntityTaskAction(ADTask adTask, Exception exception) { + ADTaskAction action = ADTaskAction.NEXT_ENTITY; + if (exception != null) { + adTask.setError(getErrorMessage(exception)); + if (exception instanceof LimitExceededException && isRetryableError(exception.getMessage())) { + action = ADTaskAction.PUSH_BACK_ENTITY; + } else if (exception instanceof TaskCancelledException || exception instanceof EndRunException) { + action = ADTaskAction.CANCEL; + } + } + return action; + } + + /** + * Check if error is retryable. + * + * @param error error + * @return retryable or not + */ + public boolean isRetryableError(String error) { + if (error == null) { + return false; + } + return retryableErrors.stream().filter(e -> error.contains(e)).findFirst().isPresent(); + } + + /** + * Set state for HC detector level task when all entities done. + * + * The state could be FINISHED,FAILED or STOPPED. + * 1. If input task state is FINISHED, will check FINISHED entity task count. 
If + * there is no FINISHED entity task, will set HC detector level task as FAILED; otherwise + * set as FINISHED. + * 2. If input task state is not FINISHED, will set HC detector level task's state as the same. + * + * @param adTask AD task + * @param state AD task state + * @param listener action listener + */ + public void setHCDetectorTaskDone(ADTask adTask, ADTaskState state, ActionListener listener) { + String detectorId = adTask.getId(); + String taskId = adTask.isEntityTask() ? adTask.getParentTaskId() : adTask.getTaskId(); + String detectorTaskId = adTask.getDetectorLevelTaskId(); + + ActionListener wrappedListener = ActionListener.wrap(response -> { + logger.info("Historical HC detector done with state: {}. Remove from cache, detector id:{}", state.name(), detectorId); + adTaskCacheManager.removeHistoricalTaskCache(detectorId); + }, e -> { + // HC detector task may fail to update as FINISHED for some edge case if failed to get updating semaphore. + // Will reset task state when get detector with task or maintain tasks in hourly cron. + if (e instanceof LimitExceededException && e.getMessage().contains(HC_DETECTOR_TASK_IS_UPDATING)) { + logger.warn("HC task is updating, skip this update for task: " + taskId); + } else { + logger.error("Failed to update task: " + taskId, e); + } + adTaskCacheManager.removeHistoricalTaskCache(detectorId); + }); + + long timeoutInMillis = 2000;// wait for 2 seconds to acquire updating HC detector task semaphore + if (state == ADTaskState.FINISHED) { + this.countEntityTasksByState(detectorTaskId, ImmutableList.of(ADTaskState.FINISHED), ActionListener.wrap(r -> { + logger.info("number of finished entity tasks: {}, for detector {}", r, adTask.getId()); + // Set task as FAILED if no finished entity task; otherwise set as FINISHED + ADTaskState hcDetectorTaskState = r == 0 ? ADTaskState.FAILED : ADTaskState.FINISHED; + // execute in AD batch task thread pool in case waiting for semaphore waste any shared OpenSearch thread pool + threadPool.executor(AD_BATCH_TASK_THREAD_POOL_NAME).execute(() -> { + updateADHCDetectorTask( + detectorId, + taskId, + ImmutableMap + .of( + STATE_FIELD, + hcDetectorTaskState.name(), + TASK_PROGRESS_FIELD, + 1.0, + EXECUTION_END_TIME_FIELD, + Instant.now().toEpochMilli() + ), + timeoutInMillis, + wrappedListener + ); + }); + + }, e -> { + logger.error("Failed to get finished entity tasks", e); + String errorMessage = getErrorMessage(e); + threadPool.executor(AD_BATCH_TASK_THREAD_POOL_NAME).execute(() -> { + updateADHCDetectorTask( + detectorId, + taskId, + ImmutableMap + .of( + STATE_FIELD, + ADTaskState.FAILED.name(),// set as FAILED if fail to get finished entity tasks. + TASK_PROGRESS_FIELD, + 1.0, + ERROR_FIELD, + errorMessage, + EXECUTION_END_TIME_FIELD, + Instant.now().toEpochMilli() + ), + timeoutInMillis, + wrappedListener + ); + }); + })); + } else { + threadPool.executor(AD_BATCH_TASK_THREAD_POOL_NAME).execute(() -> { + updateADHCDetectorTask( + detectorId, + taskId, + ImmutableMap + .of( + STATE_FIELD, + state.name(), + ERROR_FIELD, + adTask.getError(), + EXECUTION_END_TIME_FIELD, + Instant.now().toEpochMilli() + ), + timeoutInMillis, + wrappedListener + ); + }); + + } + + listener.onResponse(new AnomalyDetectorJobResponse(taskId, 0, 0, 0, RestStatus.OK)); + } + + /** + * Count entity tasks by state with detector level task id(parent task id). 
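+     * For example, {@link #setHCDetectorTaskDone} calls this with ImmutableList.of(ADTaskState.FINISHED)
+     * to count how many entity tasks finished successfully and decide whether the HC detector level
+     * task should be marked FINISHED or FAILED.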
+ * + * @param detectorTaskId detector level task id + * @param taskStates task states + * @param listener action listener + */ + public void countEntityTasksByState(String detectorTaskId, List taskStates, ActionListener listener) { + BoolQueryBuilder queryBuilder = new BoolQueryBuilder(); + queryBuilder.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, detectorTaskId)); + if (taskStates != null && taskStates.size() > 0) { + queryBuilder.filter(new TermsQueryBuilder(STATE_FIELD, taskStates.stream().map(s -> s.name()).collect(Collectors.toList()))); + } + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(queryBuilder); + sourceBuilder.size(0); + sourceBuilder.trackTotalHits(true); + SearchRequest request = new SearchRequest(); + request.source(sourceBuilder); + request.indices(DETECTION_STATE_INDEX); + client.search(request, ActionListener.wrap(r -> { + TotalHits totalHits = r.getHits().getTotalHits(); + listener.onResponse(totalHits.value); + }, e -> listener.onFailure(e))); + } + + /** + * Update HC detector level task with default action listener. There might be + * multiple entity tasks update detector task concurrently. So we will check + * if detector task is updating or not to avoid and can only update if it's + * not updating now, otherwise it may cause version conflict exception. + * + * @param detectorId detector id + * @param taskId AD task id + * @param updatedFields updated fields, key: filed name, value: new value + */ + public void updateADHCDetectorTask(String detectorId, String taskId, Map updatedFields) { + updateADHCDetectorTask(detectorId, taskId, updatedFields, 0, ActionListener.wrap(response -> { + if (response == null) { + logger.debug("Skip updating AD task: {}", taskId); + } else if (response.status() == RestStatus.OK) { + logger.debug("Updated AD task successfully: {}, taskId: {}", response.status(), taskId); + } else { + logger.error("Failed to update AD task {}, status: {}", taskId, response.status()); + } + }, e -> { + if (e instanceof LimitExceededException && e.getMessage().contains(HC_DETECTOR_TASK_IS_UPDATING)) { + logger.warn("AD HC detector task is updating, skip this update for task: " + taskId); + } else { + logger.error("Failed to update AD HC detector task: " + taskId, e); + } + })); + } + + /** + * Update HC detector level task. There might be multiple entity tasks update + * detector task concurrently. So we will check if detector task is updating + * or not to avoid and can only update if it's not updating now, otherwise it + * may cause version conflict exception. 
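+     * For illustration, {@link #setHCDetectorTaskDone} invokes this with a 2000 ms timeout and fields
+     * such as STATE_FIELD, TASK_PROGRESS_FIELD and EXECUTION_END_TIME_FIELD once all entity tasks are
+     * done.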
+ * + * @param detectorId detector id + * @param taskId AD task id + * @param updatedFields updated fields, key: filed name, value: new value + * @param timeoutInMillis the maximum time to wait for task updating semaphore, zero or negative means don't wait at all + * @param listener action listener + */ + private void updateADHCDetectorTask( + String detectorId, + String taskId, + Map updatedFields, + long timeoutInMillis, + ActionListener listener + ) { + try { + if (adTaskCacheManager.tryAcquireTaskUpdatingSemaphore(detectorId, timeoutInMillis)) { + try { + updateADTask( + taskId, + updatedFields, + ActionListener.runAfter(listener, () -> { adTaskCacheManager.releaseTaskUpdatingSemaphore(detectorId); }) + ); + } catch (Exception e) { + logger.error("Failed to update detector task " + taskId, e); + adTaskCacheManager.releaseTaskUpdatingSemaphore(detectorId); + listener.onFailure(e); + } + } else if (!adTaskCacheManager.isHCTaskCoordinatingNode(detectorId)) { + // It's possible that AD task cache cleaned up by other task. Return null to avoid too many failure logs. + logger.info("HC detector task cache does not exist, detectorId:{}, taskId:{}", detectorId, taskId); + listener.onResponse(null); + } else { + logger.info("HC detector task is updating, detectorId:{}, taskId:{}", detectorId, taskId); + listener.onFailure(new LimitExceededException(HC_DETECTOR_TASK_IS_UPDATING)); + } + } catch (Exception e) { + logger.error("Failed to get AD HC detector task updating semaphore " + taskId, e); + listener.onFailure(e); + } + } + + /** + * Scale task slots and check the scale delta: + * 1. If scale delta is negative, that means we need to scale down, will not start next entity. + * 2. If scale delta is positive, will start next entity in current lane. + * + * This method will be called by {@link org.opensearch.ad.transport.ForwardADTaskTransportAction}. + * + * @param adTask ad entity task + * @param transportService transport service + * @param listener action listener + */ + public void runNextEntityForHCADHistorical( + ADTask adTask, + TransportService transportService, + ActionListener listener + ) { + String detectorId = adTask.getId(); + int scaleDelta = scaleTaskSlots( + adTask, + transportService, + ActionListener + .wrap( + r -> { logger.debug("Scale up task slots done for detector {}, task {}", detectorId, adTask.getTaskId()); }, + e -> { logger.error("Failed to scale up task slots for task " + adTask.getTaskId(), e); } + ) + ); + if (scaleDelta < 0) { + logger + .warn( + "Have scaled down task slots. Will not poll next entity for detector {}, task {}, task slots: {}", + detectorId, + adTask.getTaskId(), + adTaskCacheManager.getDetectorTaskSlots(detectorId) + ); + listener.onResponse(new AnomalyDetectorJobResponse(detectorId, 0, 0, 0, RestStatus.ACCEPTED)); + return; + } + client.execute(ADBatchAnomalyResultAction.INSTANCE, new ADBatchAnomalyResultRequest(adTask), ActionListener.wrap(r -> { + String remoteOrLocal = r.isRunTaskRemotely() ? "remote" : "local"; + logger + .info( + "AD entity task {} of detector {} dispatched to {} node {}", + adTask.getTaskId(), + detectorId, + remoteOrLocal, + r.getNodeId() + ); + AnomalyDetectorJobResponse anomalyDetectorJobResponse = new AnomalyDetectorJobResponse(detectorId, 0, 0, 0, RestStatus.OK); + listener.onResponse(anomalyDetectorJobResponse); + }, e -> { listener.onFailure(e); })); + } + + /** + * Scale task slots and return scale delta. + * 1. If scale delta is positive, will forward scale task slots request to lead node. + * 2. 
If scale delta is negative, will decrease detector task slots in cache directly. + * + * @param adTask AD task + * @param transportService transport service + * @param scaleUpListener action listener + * @return task slots scale delta + */ + protected int scaleTaskSlots( + ADTask adTask, + TransportService transportService, + ActionListener scaleUpListener + ) { + String detectorId = adTask.getId(); + if (!scaleEntityTaskLane.tryAcquire()) { + logger.debug("Can't get scaleEntityTaskLane semaphore"); + return 0; + } + try { + int scaleDelta = detectorTaskSlotScaleDelta(detectorId); + logger.debug("start to scale task slots for detector {} with delta {}", detectorId, scaleDelta); + if (adTaskCacheManager.getAvailableNewEntityTaskLanes(detectorId) <= 0 && scaleDelta > 0) { + // scale up to run more entities in parallel + Instant lastScaleEntityTaskLaneTime = adTaskCacheManager.getLastScaleEntityTaskLaneTime(detectorId); + if (lastScaleEntityTaskLaneTime == null) { + logger.debug("lastScaleEntityTaskLaneTime is null for detector {}", detectorId); + scaleEntityTaskLane.release(); + return 0; + } + boolean lastScaleTimeExpired = lastScaleEntityTaskLaneTime + .plusMillis(SCALE_ENTITY_TASK_LANE_INTERVAL_IN_MILLIS) + .isBefore(Instant.now()); + if (lastScaleTimeExpired) { + adTaskCacheManager.refreshLastScaleEntityTaskLaneTime(detectorId); + logger.debug("Forward scale entity task lane request to lead node for detector {}", detectorId); + forwardScaleTaskSlotRequestToLeadNode( + adTask, + transportService, + ActionListener.runAfter(scaleUpListener, () -> scaleEntityTaskLane.release()) + ); + } else { + logger + .debug( + "lastScaleEntityTaskLaneTime is not expired yet: {} for detector {}", + lastScaleEntityTaskLaneTime, + detectorId + ); + scaleEntityTaskLane.release(); + } + } else { + if (scaleDelta < 0) { // scale down to release task slots for other detectors + int runningEntityCount = adTaskCacheManager.getRunningEntityCount(detectorId) + adTaskCacheManager + .getTempEntityCount(detectorId); + int assignedTaskSlots = adTaskCacheManager.getDetectorTaskSlots(detectorId); + int scaleDownDelta = Math.min(assignedTaskSlots - runningEntityCount, 0 - scaleDelta); + logger + .debug( + "Scale down task slots, scaleDelta: {}, assignedTaskSlots: {}, runningEntityCount: {}, scaleDownDelta: {}", + scaleDelta, + assignedTaskSlots, + runningEntityCount, + scaleDownDelta + ); + adTaskCacheManager.scaleDownHCDetectorTaskSlots(detectorId, scaleDownDelta); + } + scaleEntityTaskLane.release(); + } + return scaleDelta; + } catch (Exception e) { + logger.error("Failed to forward scale entity task lane request to lead node for detector " + detectorId, e); + scaleEntityTaskLane.release(); + return 0; + } + } + + /** + * Calculate scale delta for detector task slots. + * Detector's task lane limit should be less than or equal to: + * 1. Current unfinished entities: pending + running + temp + * 2. Total task slots on cluster level: eligible_data_nodes * task_slots_per_node + * 3. Max running entities per detector which is dynamic setting + * + * Task slots scale delta = task lane limit - current assigned task slots + * + * If current assigned task slots to this detector is less than task lane limit, we need + * to scale up(return positive value); otherwise we need to scale down (return negative + * value). 
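+     * Worked example with hypothetical numbers: 3 eligible data nodes with 10 task slots each give
+     * 30 cluster-level task slots; with 100 unfinished entities and a per-detector limit of 10 running
+     * entities, the task lane limit is min(100, min(30, 10)) = 10. If 6 task slots are currently
+     * assigned, the scale delta is 10 - 6 = 4 (scale up); 12 assigned slots would give -2 (scale down).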
+ * + * @param detectorId detector id + * @return detector task slots scale delta + */ + public int detectorTaskSlotScaleDelta(String detectorId) { + DiscoveryNode[] eligibleDataNodes = hashRing.getNodesWithSameLocalAdVersion(); + int unfinishedEntities = adTaskCacheManager.getUnfinishedEntityCount(detectorId); + int totalTaskSlots = eligibleDataNodes.length * maxAdBatchTaskPerNode; + int taskLaneLimit = Math.min(unfinishedEntities, Math.min(totalTaskSlots, maxRunningEntitiesPerDetector)); + adTaskCacheManager.setDetectorTaskLaneLimit(detectorId, taskLaneLimit); + + int assignedTaskSlots = adTaskCacheManager.getDetectorTaskSlots(detectorId); + int scaleDelta = taskLaneLimit - assignedTaskSlots; + logger + .debug( + "Calculate task slot scale delta for detector {}, totalTaskSlots: {}, maxRunningEntitiesPerDetector: {}, " + + "unfinishedEntities: {}, taskLaneLimit: {}, assignedTaskSlots: {}, scaleDelta: {}", + detectorId, + totalTaskSlots, + maxRunningEntitiesPerDetector, + unfinishedEntities, + taskLaneLimit, + assignedTaskSlots, + scaleDelta + ); + return scaleDelta; + } + + /** + * Calculate historical analysis task progress of HC detector. + * task_progress = finished_entity_count / total_entity_count + * @param detectorId detector id + * @return task progress + */ + public float hcDetectorProgress(String detectorId) { + int entityCount = adTaskCacheManager.getTopEntityCount(detectorId); + int leftEntities = adTaskCacheManager.getPendingEntityCount(detectorId) + adTaskCacheManager.getRunningEntityCount(detectorId); + return 1 - (float) leftEntities / entityCount; + } + + /** + * Get local task profiles of detector. + * @param detectorId detector id + * @return list of AD task profile + */ + public ADTaskProfile getLocalADTaskProfilesByDetectorId(String detectorId) { + List tasksOfDetector = adTaskCacheManager.getTasksOfDetector(detectorId); + ADTaskProfile detectorTaskProfile = null; + + String localNodeId = clusterService.localNode().getId(); + if (adTaskCacheManager.isHCTaskRunning(detectorId)) { + detectorTaskProfile = new ADTaskProfile(); + if (adTaskCacheManager.isHCTaskCoordinatingNode(detectorId)) { + detectorTaskProfile.setNodeId(localNodeId); + detectorTaskProfile.setTaskId(adTaskCacheManager.getDetectorTaskId(detectorId)); + detectorTaskProfile.setDetectorTaskSlots(adTaskCacheManager.getDetectorTaskSlots(detectorId)); + detectorTaskProfile.setTotalEntitiesInited(adTaskCacheManager.topEntityInited(detectorId)); + detectorTaskProfile.setTotalEntitiesCount(adTaskCacheManager.getTopEntityCount(detectorId)); + detectorTaskProfile.setPendingEntitiesCount(adTaskCacheManager.getPendingEntityCount(detectorId)); + detectorTaskProfile.setRunningEntitiesCount(adTaskCacheManager.getRunningEntityCount(detectorId)); + detectorTaskProfile.setRunningEntities(adTaskCacheManager.getRunningEntities(detectorId)); + detectorTaskProfile.setAdTaskType(ADTaskType.HISTORICAL_HC_DETECTOR.name()); + Instant latestHCTaskRunTime = adTaskCacheManager.getLatestHCTaskRunTime(detectorId); + if (latestHCTaskRunTime != null) { + detectorTaskProfile.setLatestHCTaskRunTime(latestHCTaskRunTime.toEpochMilli()); + } + } + if (tasksOfDetector.size() > 0) { + List entityTaskProfiles = new ArrayList<>(); + + tasksOfDetector.forEach(taskId -> { + ADEntityTaskProfile entityTaskProfile = new ADEntityTaskProfile( + adTaskCacheManager.getShingle(taskId).size(), + adTaskCacheManager.getTRcfModel(taskId).getForest().getTotalUpdates(), + adTaskCacheManager.isThresholdModelTrained(taskId), + 
adTaskCacheManager.getThresholdModelTrainingDataSize(taskId), + adTaskCacheManager.getModelSize(taskId), + localNodeId, + adTaskCacheManager.getEntity(taskId), + taskId, + ADTaskType.HISTORICAL_HC_ENTITY.name() + ); + entityTaskProfiles.add(entityTaskProfile); + }); + detectorTaskProfile.setEntityTaskProfiles(entityTaskProfiles); + } + } else { + if (tasksOfDetector.size() > 1) { + String error = "Multiple tasks are running for detector: " + + detectorId + + ". You can stop detector to kill all running tasks."; + logger.error(error); + throw new LimitExceededException(error); + } + if (tasksOfDetector.size() == 1) { + String taskId = tasksOfDetector.get(0); + detectorTaskProfile = new ADTaskProfile( + adTaskCacheManager.getDetectorTaskId(detectorId), + adTaskCacheManager.getShingle(taskId).size(), + adTaskCacheManager.getTRcfModel(taskId).getForest().getTotalUpdates(), + adTaskCacheManager.isThresholdModelTrained(taskId), + adTaskCacheManager.getThresholdModelTrainingDataSize(taskId), + adTaskCacheManager.getModelSize(taskId), + localNodeId + ); + // Single-flow detector only has 1 task slot. + // Can't use adTaskCacheManager.getDetectorTaskSlots(detectorId) here as task may run on worker node. + // Detector task slots stored in coordinating node cache. + detectorTaskProfile.setDetectorTaskSlots(1); + } + } + threadPool + .executor(AD_BATCH_TASK_THREAD_POOL_NAME) + .execute( + () -> { + // Clean expired HC batch task run states as it may exists after HC historical analysis done if user cancel + // before querying top entities done. We will clean it in hourly cron, check "maintainRunningHistoricalTasks" + // method. Clean it up here when get task profile to release memory earlier. + adTaskCacheManager.cleanExpiredHCBatchTaskRunStates(); + } + ); + logger.debug("Local AD task profile of detector {}: {}", detectorId, detectorTaskProfile); + return detectorTaskProfile; + } + + /** + * Remove stale running entity from coordinating node cache. If no more entities, reset task as STOPPED. + * + * Explain details with an example. + * + * Note: + * CN: coordinating mode; + * WN1: worker node 1; + * WN2: worker node 2. + * [x,x] means running entity in cache. + * eX like e1: entity. + * + * Assume HC detector can run 2 entities at most and current cluster state is: + * CN: [e1, e2]; + * WN1: [e1] + * WN2: [e2] + * + * If WN1 crashes, then e1 will never removed from CN cache. User can call get detector API with "task=true" + * to reset task state. Let's say User1 and User2 call get detector API at the same time. Then User1 and User2 + * both know e1 is stale running entity and try to remove from CN cache. If User1 request arrives first, then + * it will remove e1 from CN, then CN cache will be [e2]. As we can run 2 entities per HC detector, so we can + * kick off another pending entity. Then CN cache changes to [e2, e3]. Then User2 request arrives, it will find + * e1 not in CN cache ([e2, e3]) which means e1 has been removed by other request. We can't kick off another + * pending entity for User2 request, otherwise we will run more than 2 entities for this HC detector. + * + * Why we don't put the stale running entity back to pending and retry? + * The stale entity has been ran on some worker node and the old task run may generate some or all AD results + * for the stale entity. Just because of the worker node crash or entity task done message not received by + * coordinating node, the entity will never be deleted from running entity queue. 
We can check if the stale + * entity has all AD results generated for the whole date range. If not, we can rerun. This make the design + * complex as we need to store model checkpoints to resume from last break point and we need to handle kinds + * of edge cases. Here we just ignore the stale running entity rather than rerun it. We plan to add scheduler + * on historical analysis, then we will store model checkpoints. Will support resuming failed tasks by then. + * //TODO: support resuming failed task + * + * @param adTask AD task + * @param entity entity value + * @param transportService transport service + * @param listener action listener + */ + public synchronized void removeStaleRunningEntity( + ADTask adTask, + String entity, + TransportService transportService, + ActionListener listener + ) { + String detectorId = adTask.getId(); + boolean removed = adTaskCacheManager.removeRunningEntity(detectorId, entity); + if (removed && adTaskCacheManager.getPendingEntityCount(detectorId) > 0) { + logger.debug("kick off next pending entities"); + this.runNextEntityForHCADHistorical(adTask, transportService, listener); + } else { + if (!adTaskCacheManager.hasEntity(detectorId)) { + setHCDetectorTaskDone(adTask, ADTaskState.STOPPED, listener); + } + } + } + + public boolean skipUpdateHCRealtimeTask(String detectorId, String error) { + ADRealtimeTaskCache realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); + return realtimeTaskCache != null + && realtimeTaskCache.getInitProgress() != null + && realtimeTaskCache.getInitProgress().floatValue() == 1.0 + && Objects.equals(error, realtimeTaskCache.getError()); + } + + public boolean isHCRealtimeTaskStartInitializing(String detectorId) { + ADRealtimeTaskCache realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); + return realtimeTaskCache != null + && realtimeTaskCache.getInitProgress() != null + && realtimeTaskCache.getInitProgress().floatValue() > 0; + } + + public String convertEntityToString(ADTask adTask) { + if (adTask == null || !adTask.isEntityTask()) { + return null; + } + AnomalyDetector detector = adTask.getDetector(); + return convertEntityToString(adTask.getEntity(), detector); + } + + /** + * Convert {@link Entity} instance to string. + * @param entity entity + * @param detector detector + * @return entity string value + */ + public String convertEntityToString(Entity entity, AnomalyDetector detector) { + if (detector.hasMultipleCategories()) { + try { + XContentBuilder builder = entity.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + return BytesReference.bytes(builder).utf8ToString(); + } catch (IOException e) { + String error = "Failed to parse entity into string"; + logger.debug(error, e); + throw new TimeSeriesException(error); + } + } + if (detector.isHighCardinality()) { + String categoryField = detector.getCategoryFields().get(0); + return entity.getAttributes().get(categoryField); + } + return null; + } + + /** + * Parse entity string value into Entity {@link Entity} instance. 
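+     * For a single-category detector the entity value is the raw attribute value and is wrapped via
+     * Entity.createSingleAttributeEntity(categoryField, entityValue); for a detector with multiple
+     * categories the value is expected to be the JSON string produced by
+     * {@link #convertEntityToString(Entity, AnomalyDetector)}.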
+ * @param entityValue entity value + * @param adTask AD task + * @return Entity instance + */ + public Entity parseEntityFromString(String entityValue, ADTask adTask) { + AnomalyDetector detector = adTask.getDetector(); + if (detector.hasMultipleCategories()) { + try { + XContentParser parser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, entityValue); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.nextToken(), parser); + return Entity.parse(parser); + } catch (IOException e) { + String error = "Failed to parse string into entity"; + logger.debug(error, e); + throw new TimeSeriesException(error); + } + } else if (detector.isHighCardinality()) { + return Entity.createSingleAttributeEntity(detector.getCategoryFields().get(0), entityValue); + } + throw new IllegalArgumentException("Fail to parse to Entity for single flow detector"); + } + + /** + * Get AD task with task id and execute listener. + * @param taskId task id + * @param listener action listener + */ + public void getADTask(String taskId, ActionListener> listener) { + GetRequest request = new GetRequest(DETECTION_STATE_INDEX, taskId); + client.get(request, ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + ADTask adTask = ADTask.parse(parser, r.getId()); + listener.onResponse(Optional.ofNullable(adTask)); + } catch (Exception e) { + logger.error("Failed to parse AD task " + r.getId(), e); + listener.onFailure(e); + } + } else { + listener.onResponse(Optional.empty()); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + listener.onResponse(Optional.empty()); + } else { + logger.error("Failed to get AD task " + taskId, e); + listener.onFailure(e); + } + })); + } + + /** + * Set old AD task's latest flag as false. + * @param adTasks list of AD tasks + */ + public void resetLatestFlagAsFalse(List adTasks) { + if (adTasks == null || adTasks.size() == 0) { + return; + } + BulkRequest bulkRequest = new BulkRequest(); + adTasks.forEach(task -> { + try { + task.setLatest(false); + task.setLastUpdateTime(Instant.now()); + IndexRequest indexRequest = new IndexRequest(DETECTION_STATE_INDEX) + .id(task.getTaskId()) + .source(task.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), XCONTENT_WITH_TYPE)); + bulkRequest.add(indexRequest); + } catch (Exception e) { + logger.error("Fail to parse task AD task to XContent, task id " + task.getTaskId(), e); + } + }); + + bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(res -> { + BulkItemResponse[] bulkItemResponses = res.getItems(); + if (bulkItemResponses != null && bulkItemResponses.length > 0) { + for (BulkItemResponse bulkItemResponse : bulkItemResponses) { + if (!bulkItemResponse.isFailed()) { + logger.warn("Reset AD tasks latest flag as false Successfully. Task id: {}", bulkItemResponse.getId()); + } else { + logger.warn("Failed to reset AD tasks latest flag as false. 
Task id: " + bulkItemResponse.getId()); + } + } + } + }, e -> { logger.warn("Failed to reset AD tasks latest flag as false", e); })); + } + + public int getLocalAdUsedBatchTaskSlot() { + return adTaskCacheManager.getTotalBatchTaskCount(); + } + + /** + * In normal case "assigned_task_slots" should always greater than or equals to "used_task_slots". One + * example may help understand. + * + * If a cluster has 3 data nodes, HC detector was assigned 3 task slots. It will start task lane one by one, + * so the timeline looks like: + * + * t1 -- node1: 1 entity running + * t2 -- node1: 1 entity running , + * node2: 1 entity running + * t3 -- node1: 1 entity running , + * node2: 1 entity running, + * node3: 1 entity running + * + * So if we check between t2 and t3, we can see assigned task slots (3) is greater than used task slots (2). + * + * Assume node1 is coordinating node, the assigned task slots will be cached on node1. If node1 left cluster, + * then we don't know how many task slots was assigned to the detector. But the detector will not send out + * more entity tasks as well due to coordinating node left. So we can calculate how many task slots used on + * node2 and node3 to calculate how many task slots available for new detector. + * @return assigned batch task slots + */ + public int getLocalAdAssignedBatchTaskSlot() { + return adTaskCacheManager.getTotalDetectorTaskSlots(); + } + + // ========================================================= + // Methods below are maintenance code triggered by hourly cron + // ========================================================= + + /** + * Maintain running historical tasks. + * Search current running latest tasks, then maintain tasks one by one. + * Get task profile to check if task is really running on worker node. + * 1. If not running, reset task state as STOPPED. + * 2. If task is running and task for HC detector, check if there is any stale running entities and + * clean up. + * + * @param transportService transport service + * @param size return how many tasks + */ + public void maintainRunningHistoricalTasks(TransportService transportService, int size) { + // Clean expired HC batch task run state cache. + adTaskCacheManager.cleanExpiredHCBatchTaskRunStates(); + + // Find owning node with highest AD version to make sure we only have 1 node maintain running historical tasks + // and we use the latest logic. + Optional owningNode = hashRing.getOwningNodeWithHighestAdVersion(AD_TASK_MAINTAINENCE_NODE_MODEL_ID); + if (!owningNode.isPresent() || !clusterService.localNode().getId().equals(owningNode.get().getId())) { + return; + } + logger.info("Start to maintain running historical tasks"); + + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true)); + query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(HISTORICAL_DETECTOR_TASK_TYPES))); + query.filter(new TermsQueryBuilder(STATE_FIELD, NOT_ENDED_STATES)); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + // default maintain interval is 5 seconds, so maintain 10 tasks will take at least 50 seconds. 
+ sourceBuilder.query(query).sort(LAST_UPDATE_TIME_FIELD, SortOrder.DESC).size(size); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(sourceBuilder); + searchRequest.indices(DETECTION_STATE_INDEX); + + client.search(searchRequest, ActionListener.wrap(r -> { + if (r == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { + return; + } + ConcurrentLinkedQueue taskQueue = new ConcurrentLinkedQueue<>(); + Iterator iterator = r.getHits().iterator(); + while (iterator.hasNext()) { + SearchHit searchHit = iterator.next(); + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + taskQueue.add(ADTask.parse(parser, searchHit.getId())); + } catch (Exception e) { + logger.error("Maintaining running historical task: failed to parse AD task " + searchHit.getId(), e); + } + } + maintainRunningHistoricalTask(taskQueue, transportService); + }, e -> { + if (e instanceof IndexNotFoundException) { + // the method will be called hourly + // don't log stack trace as most of OpenSearch domains have no AD installed + logger.debug(STATE_INDEX_NOT_EXIST_MSG); + } else { + logger.error("Failed to search historical tasks in maintaining job", e); + } + })); + } + + private void maintainRunningHistoricalTask(ConcurrentLinkedQueue taskQueue, TransportService transportService) { + ADTask adTask = taskQueue.poll(); + if (adTask == null) { + return; + } + threadPool.schedule(() -> { + resetHistoricalDetectorTaskState(ImmutableList.of(adTask), () -> { + logger.debug("Finished maintaining running historical task {}", adTask.getTaskId()); + maintainRunningHistoricalTask(taskQueue, transportService); + }, + transportService, + ActionListener + .wrap( + r -> { + logger.debug("Reset historical task state done for task {}, detector {}", adTask.getTaskId(), adTask.getId()); + }, + e -> { logger.error("Failed to reset historical task state for task " + adTask.getTaskId(), e); } + ) + ); + }, TimeValue.timeValueSeconds(DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS), AD_BATCH_TASK_THREAD_POOL_NAME); + } + + /** + * Maintain running realtime tasks. Check if realtime task cache expires or not. Remove realtime + * task cache directly if expired. + */ + public void maintainRunningRealtimeTasks() { + String[] detectorIds = adTaskCacheManager.getDetectorIdsInRealtimeTaskCache(); + if (detectorIds == null || detectorIds.length == 0) { + return; + } + for (int i = 0; i < detectorIds.length; i++) { + String detectorId = detectorIds[i]; + ADRealtimeTaskCache taskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); + if (taskCache != null && taskCache.expired()) { + adTaskCacheManager.removeRealtimeTaskCache(detectorId); + } + } + } +} diff --git a/src/main/java/org/opensearch/ad/task/ADTaskSlotLimit.java-e b/src/main/java/org/opensearch/ad/task/ADTaskSlotLimit.java-e new file mode 100644 index 000000000..e5491f1da --- /dev/null +++ b/src/main/java/org/opensearch/ad/task/ADTaskSlotLimit.java-e @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.task; + +public class ADTaskSlotLimit { + // Task slots assigned to detector + private Integer detectorTaskSlots; + // How many task lanes this detector can start at most + private Integer detectorTaskLaneLimit; + + public ADTaskSlotLimit(Integer detectorTaskSlots, Integer detectorTaskLaneLimit) { + this.detectorTaskSlots = detectorTaskSlots; + this.detectorTaskLaneLimit = detectorTaskLaneLimit; + } + + public Integer getDetectorTaskSlots() { + return detectorTaskSlots; + } + + public void setDetectorTaskSlots(Integer detectorTaskSlots) { + this.detectorTaskSlots = detectorTaskSlots; + } + + public void setDetectorTaskLaneLimit(Integer detectorTaskLaneLimit) { + this.detectorTaskLaneLimit = detectorTaskLaneLimit; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultAction.java-e b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultAction.java-e new file mode 100644 index 000000000..84fe0c6fe --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultAction.java-e @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonName.AD_TASK; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class ADBatchAnomalyResultAction extends ActionType { + public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK; + public static final ADBatchAnomalyResultAction INSTANCE = new ADBatchAnomalyResultAction(); + + private ADBatchAnomalyResultAction() { + super(NAME, ADBatchAnomalyResultResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultRequest.java b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultRequest.java index 7d2d86f90..276146840 100644 --- a/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultRequest.java @@ -19,9 +19,9 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; public class ADBatchAnomalyResultRequest extends ActionRequest { private ADTask adTask; diff --git a/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultRequest.java-e b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultRequest.java-e new file mode 100644 index 000000000..d6e696693 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultRequest.java-e @@ -0,0 +1,65 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.Strings; + +public class ADBatchAnomalyResultRequest extends ActionRequest { + private ADTask adTask; + + public ADBatchAnomalyResultRequest(StreamInput in) throws IOException { + super(in); + adTask = new ADTask(in); + } + + public ADBatchAnomalyResultRequest(ADTask adTask) { + super(); + this.adTask = adTask; + } + + public ADTask getAdTask() { + return adTask; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + adTask.writeTo(out); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(adTask.getTaskId())) { + validationException = addValidationError("Task id can't be null", validationException); + } + if (adTask.getDetectionDateRange() == null) { + validationException = addValidationError("Detection date range can't be null for batch task", validationException); + } + AnomalyDetector detector = adTask.getDetector(); + if (detector == null) { + validationException = addValidationError("Detector can't be null", validationException); + } + return validationException; + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultResponse.java b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultResponse.java index effe18692..608b53aa9 100644 --- a/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultResponse.java @@ -14,8 +14,8 @@ import java.io.IOException; import org.opensearch.action.ActionResponse; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; public class ADBatchAnomalyResultResponse extends ActionResponse { public String nodeId; diff --git a/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultResponse.java-e b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultResponse.java-e new file mode 100644 index 000000000..608b53aa9 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultResponse.java-e @@ -0,0 +1,49 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class ADBatchAnomalyResultResponse extends ActionResponse { + public String nodeId; + public boolean runTaskRemotely; + + public ADBatchAnomalyResultResponse(String nodeId, boolean runTaskRemotely) { + this.nodeId = nodeId; + this.runTaskRemotely = runTaskRemotely; + } + + public ADBatchAnomalyResultResponse(StreamInput in) throws IOException { + super(in); + nodeId = in.readString(); + runTaskRemotely = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(nodeId); + out.writeBoolean(runTaskRemotely); + } + + public String getNodeId() { + return nodeId; + } + + public boolean isRunTaskRemotely() { + return runTaskRemotely; + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportAction.java-e new file mode 100644 index 000000000..9105ece61 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportAction.java-e @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.task.ADBatchTaskRunner; +import org.opensearch.common.inject.Inject; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class ADBatchAnomalyResultTransportAction extends HandledTransportAction { + + private final TransportService transportService; + private final ADBatchTaskRunner adBatchTaskRunner; + + @Inject + public ADBatchAnomalyResultTransportAction( + TransportService transportService, + ActionFilters actionFilters, + ADBatchTaskRunner adBatchTaskRunner + ) { + super(ADBatchAnomalyResultAction.NAME, transportService, actionFilters, ADBatchAnomalyResultRequest::new); + this.transportService = transportService; + this.adBatchTaskRunner = adBatchTaskRunner; + } + + @Override + protected void doExecute(Task task, ADBatchAnomalyResultRequest request, ActionListener actionListener) { + adBatchTaskRunner.run(request.getAdTask(), transportService, actionListener); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionAction.java-e b/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionAction.java-e new file mode 100644 index 000000000..d865ec14c --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionAction.java-e @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonName.AD_TASK_REMOTE; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class ADBatchTaskRemoteExecutionAction extends ActionType { + public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK_REMOTE; + public static final ADBatchTaskRemoteExecutionAction INSTANCE = new ADBatchTaskRemoteExecutionAction(); + + private ADBatchTaskRemoteExecutionAction() { + super(NAME, ADBatchAnomalyResultResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionTransportAction.java-e new file mode 100644 index 000000000..736262865 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionTransportAction.java-e @@ -0,0 +1,43 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.task.ADBatchTaskRunner; +import org.opensearch.common.inject.Inject; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class ADBatchTaskRemoteExecutionTransportAction extends + HandledTransportAction { + + private final ADBatchTaskRunner adBatchTaskRunner; + private final TransportService transportService; + + @Inject + public ADBatchTaskRemoteExecutionTransportAction( + ActionFilters actionFilters, + TransportService transportService, + ADBatchTaskRunner adBatchTaskRunner + ) { + super(ADBatchTaskRemoteExecutionAction.NAME, transportService, actionFilters, ADBatchAnomalyResultRequest::new); + this.adBatchTaskRunner = adBatchTaskRunner; + this.transportService = transportService; + } + + @Override + protected void doExecute(Task task, ADBatchAnomalyResultRequest request, ActionListener listener) { + adBatchTaskRunner.startADBatchTaskOnWorkerNode(request.getAdTask(), true, transportService, listener); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADCancelTaskAction.java-e b/src/main/java/org/opensearch/ad/transport/ADCancelTaskAction.java-e new file mode 100644 index 000000000..31f20fa00 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADCancelTaskAction.java-e @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonName.CANCEL_TASK; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class ADCancelTaskAction extends ActionType { + + public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detectors/" + CANCEL_TASK; + public static final ADCancelTaskAction INSTANCE = new ADCancelTaskAction(); + + private ADCancelTaskAction() { + super(NAME, ADCancelTaskResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ADCancelTaskNodeRequest.java b/src/main/java/org/opensearch/ad/transport/ADCancelTaskNodeRequest.java index ec71120bb..61157c5f7 100644 --- a/src/main/java/org/opensearch/ad/transport/ADCancelTaskNodeRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ADCancelTaskNodeRequest.java @@ -13,8 +13,8 @@ import java.io.IOException; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.transport.TransportRequest; public class ADCancelTaskNodeRequest extends TransportRequest { diff --git a/src/main/java/org/opensearch/ad/transport/ADCancelTaskNodeRequest.java-e b/src/main/java/org/opensearch/ad/transport/ADCancelTaskNodeRequest.java-e new file mode 100644 index 000000000..61157c5f7 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADCancelTaskNodeRequest.java-e @@ -0,0 +1,67 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +public class ADCancelTaskNodeRequest extends TransportRequest { + private String detectorId; + private String detectorTaskId; + private String userName; + private String reason; + + public ADCancelTaskNodeRequest(StreamInput in) throws IOException { + super(in); + this.detectorId = in.readOptionalString(); + this.userName = in.readOptionalString(); + if (in.available() > 0) { + this.detectorTaskId = in.readOptionalString(); + this.reason = in.readOptionalString(); + } + } + + public ADCancelTaskNodeRequest(ADCancelTaskRequest request) { + this.detectorId = request.getId(); + this.detectorTaskId = request.getDetectorTaskId(); + this.userName = request.getUserName(); + this.reason = request.getReason(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalString(detectorId); + out.writeOptionalString(userName); + out.writeOptionalString(detectorTaskId); + out.writeOptionalString(reason); + } + + public String getId() { + return detectorId; + } + + public String getDetectorTaskId() { + return detectorTaskId; + } + + public String getUserName() { + return userName; + } + + public String getReason() { + return reason; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADCancelTaskNodeResponse.java b/src/main/java/org/opensearch/ad/transport/ADCancelTaskNodeResponse.java index e2a07560f..1a09f7c16 100644 --- a/src/main/java/org/opensearch/ad/transport/ADCancelTaskNodeResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ADCancelTaskNodeResponse.java @@ -16,8 +16,8 @@ import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.ad.task.ADTaskCancellationState; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; public class ADCancelTaskNodeResponse extends BaseNodeResponse { diff --git a/src/main/java/org/opensearch/ad/transport/ADCancelTaskNodeResponse.java-e b/src/main/java/org/opensearch/ad/transport/ADCancelTaskNodeResponse.java-e new file mode 100644 index 000000000..1a09f7c16 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADCancelTaskNodeResponse.java-e @@ -0,0 +1,49 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
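ADCancelTaskNodeRequest above reads detectorTaskId and reason only when bytes remain on the stream (in.available() > 0), so node requests written by older plugin versions that never serialized those trailing fields still deserialize cleanly. The following is a minimal, self-contained sketch of that trailing-optional-field pattern; it uses plain java.io streams as a stand-in for OpenSearch's StreamInput/StreamOutput and is illustrative only.

// Plain-Java analog of "read trailing fields only if bytes remain";
// DataInput/DataOutput streams stand in for StreamInput/StreamOutput.
import java.io.*;

public class TrailingFieldDemo {
    static byte[] writeV1(String detectorId) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeUTF(detectorId);                 // older writer: id only
        }
        return bytes.toByteArray();
    }

    static byte[] writeV2(String detectorId, String reason) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeUTF(detectorId);
            out.writeUTF(reason);                     // newer writer appends a field
        }
        return bytes.toByteArray();
    }

    static String[] read(byte[] payload) throws IOException {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(payload))) {
            String detectorId = in.readUTF();
            String reason = in.available() > 0 ? in.readUTF() : null;  // tolerate old payloads
            return new String[] { detectorId, reason };
        }
    }

    public static void main(String[] args) throws IOException {
        System.out.println(read(writeV1("d1"))[1]);                    // null: old writer, new reader
        System.out.println(read(writeV2("d1", "user cancelled"))[1]);  // "user cancelled"
    }
}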
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.ad.task.ADTaskCancellationState; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class ADCancelTaskNodeResponse extends BaseNodeResponse { + + private ADTaskCancellationState state; + + public ADCancelTaskNodeResponse(DiscoveryNode node, ADTaskCancellationState state) { + super(node); + this.state = state; + } + + public ADCancelTaskNodeResponse(StreamInput in) throws IOException { + super(in); + this.state = in.readEnum(ADTaskCancellationState.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeEnum(state); + } + + public static ADCancelTaskNodeResponse readNodeResponse(StreamInput in) throws IOException { + return new ADCancelTaskNodeResponse(in); + } + + public ADTaskCancellationState getState() { + return state; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADCancelTaskRequest.java b/src/main/java/org/opensearch/ad/transport/ADCancelTaskRequest.java index 9b07add33..ddfbd6a53 100644 --- a/src/main/java/org/opensearch/ad/transport/ADCancelTaskRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ADCancelTaskRequest.java @@ -19,9 +19,9 @@ import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; public class ADCancelTaskRequest extends BaseNodesRequest { diff --git a/src/main/java/org/opensearch/ad/transport/ADCancelTaskRequest.java-e b/src/main/java/org/opensearch/ad/transport/ADCancelTaskRequest.java-e new file mode 100644 index 000000000..bf2df1212 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADCancelTaskRequest.java-e @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.Strings; + +public class ADCancelTaskRequest extends BaseNodesRequest { + + private String detectorId; + private String detectorTaskId; + private String userName; + private String reason; + + public ADCancelTaskRequest(StreamInput in) throws IOException { + super(in); + this.detectorId = in.readOptionalString(); + this.userName = in.readOptionalString(); + if (in.available() > 0) { + this.detectorTaskId = in.readOptionalString(); + this.reason = in.readOptionalString(); + } + } + + public ADCancelTaskRequest(String detectorId, String detectorTaskId, String userName, DiscoveryNode... nodes) { + this(detectorId, detectorTaskId, userName, null, nodes); + } + + public ADCancelTaskRequest(String detectorId, String detectorTaskId, String userName, String reason, DiscoveryNode... nodes) { + super(nodes); + this.detectorId = detectorId; + this.detectorTaskId = detectorTaskId; + this.userName = userName; + this.reason = reason; + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(detectorId)) { + validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + } + return validationException; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalString(detectorId); + out.writeOptionalString(userName); + out.writeOptionalString(detectorTaskId); + out.writeOptionalString(reason); + } + + public String getId() { + return detectorId; + } + + public String getDetectorTaskId() { + return detectorTaskId; + } + + public String getUserName() { + return userName; + } + + public String getReason() { + return reason; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADCancelTaskResponse.java b/src/main/java/org/opensearch/ad/transport/ADCancelTaskResponse.java index 3fec05caa..4e0bf5464 100644 --- a/src/main/java/org/opensearch/ad/transport/ADCancelTaskResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ADCancelTaskResponse.java @@ -17,8 +17,8 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; public class ADCancelTaskResponse extends BaseNodesResponse { diff --git a/src/main/java/org/opensearch/ad/transport/ADCancelTaskResponse.java-e b/src/main/java/org/opensearch/ad/transport/ADCancelTaskResponse.java-e new file mode 100644 index 000000000..4e0bf5464 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADCancelTaskResponse.java-e @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * 
compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class ADCancelTaskResponse extends BaseNodesResponse { + + public ADCancelTaskResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(ADCancelTaskNodeResponse::readNodeResponse), in.readList(FailedNodeException::new)); + } + + public ADCancelTaskResponse(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + } + + @Override + public void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + @Override + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(ADCancelTaskNodeResponse::readNodeResponse); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADCancelTaskTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADCancelTaskTransportAction.java index 801910f96..03d0a5861 100644 --- a/src/main/java/org/opensearch/ad/transport/ADCancelTaskTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADCancelTaskTransportAction.java @@ -26,7 +26,7 @@ import org.opensearch.ad.task.ADTaskManager; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; diff --git a/src/main/java/org/opensearch/ad/transport/ADCancelTaskTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/ADCancelTaskTransportAction.java-e new file mode 100644 index 000000000..03d0a5861 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADCancelTaskTransportAction.java-e @@ -0,0 +1,89 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonMessages.HISTORICAL_ANALYSIS_CANCELLED; + +import java.io.IOException; +import java.util.List; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.ad.task.ADTaskCancellationState; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class ADCancelTaskTransportAction extends + TransportNodesAction { + private final Logger logger = LogManager.getLogger(ADCancelTaskTransportAction.class); + private ADTaskManager adTaskManager; + + @Inject + public ADCancelTaskTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + ADTaskManager adTaskManager + ) { + super( + ADCancelTaskAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + ADCancelTaskRequest::new, + ADCancelTaskNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + ADCancelTaskNodeResponse.class + ); + this.adTaskManager = adTaskManager; + } + + @Override + protected ADCancelTaskResponse newResponse( + ADCancelTaskRequest request, + List responses, + List failures + ) { + return new ADCancelTaskResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected ADCancelTaskNodeRequest newNodeRequest(ADCancelTaskRequest request) { + return new ADCancelTaskNodeRequest(request); + } + + @Override + protected ADCancelTaskNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new ADCancelTaskNodeResponse(in); + } + + @Override + protected ADCancelTaskNodeResponse nodeOperation(ADCancelTaskNodeRequest request) { + String userName = request.getUserName(); + String detectorId = request.getId(); + String detectorTaskId = request.getDetectorTaskId(); + String reason = Optional.ofNullable(request.getReason()).orElse(HISTORICAL_ANALYSIS_CANCELLED); + ADTaskCancellationState state = adTaskManager.cancelLocalTaskByDetectorId(detectorId, detectorTaskId, reason, userName); + logger.debug("Cancelled AD task for detector: {}", request.getId()); + return new ADCancelTaskNodeResponse(clusterService.localNode(), state); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java-e b/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java-e new file mode 100644 index 000000000..041d543b7 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java-e @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
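ADCancelTaskTransportAction above is a standard nodes-level fan-out: the coordinating node turns one ADCancelTaskRequest into per-node ADCancelTaskNodeRequests, each node cancels its local batch task through ADTaskManager, and the per-node cancellation states are folded back into a single ADCancelTaskResponse. A hedged usage sketch of how a caller might dispatch it follows; the Client, DiscoveryNode array, and listener are assumed to be supplied by the caller (types as imported elsewhere in this diff), and this is not code from the plugin itself.

// Sketch only: dispatching the cancel fan-out from a coordinating node.
void cancelHistoricalAnalysis(
    Client client,
    String detectorId,
    String detectorTaskId,
    String userName,
    DiscoveryNode[] dataNodes,
    ActionListener<ADCancelTaskResponse> listener
) {
    // Each targeted node replies with its ADTaskCancellationState.
    ADCancelTaskRequest request = new ADCancelTaskRequest(detectorId, detectorTaskId, userName, dataNodes);
    client.execute(ADCancelTaskAction.INSTANCE, request, listener);
}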
+ */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; +import org.opensearch.common.settings.Settings; +import org.opensearch.transport.TransportRequestOptions; + +public class ADResultBulkAction extends ActionType { + + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "write/bulk"; + public static final ADResultBulkAction INSTANCE = new ADResultBulkAction(); + + private ADResultBulkAction() { + super(NAME, ADResultBulkResponse::new); + } + + @Override + public TransportRequestOptions transportOptions(Settings settings) { + return TransportRequestOptions.builder().withType(TransportRequestOptions.Type.BULK).build(); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java b/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java index 1bc00a56b..f5f361f69 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java @@ -19,9 +19,9 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.ValidateActions; import org.opensearch.ad.ratelimit.ResultWriteRequest; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; public class ADResultBulkRequest extends ActionRequest implements Writeable { private final List anomalyResults; diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java-e b/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java-e new file mode 100644 index 000000000..f5f361f69 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java-e @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.ValidateActions; +import org.opensearch.ad.ratelimit.ResultWriteRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; + +public class ADResultBulkRequest extends ActionRequest implements Writeable { + private final List anomalyResults; + static final String NO_REQUESTS_ADDED_ERR = "no requests added"; + + public ADResultBulkRequest() { + anomalyResults = new ArrayList<>(); + } + + public ADResultBulkRequest(StreamInput in) throws IOException { + super(in); + int size = in.readVInt(); + anomalyResults = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + anomalyResults.add(new ResultWriteRequest(in)); + } + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (anomalyResults.isEmpty()) { + validationException = ValidateActions.addValidationError(NO_REQUESTS_ADDED_ERR, validationException); + } + return validationException; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeVInt(anomalyResults.size()); + for (ResultWriteRequest result : anomalyResults) { + result.writeTo(out); + } + } + + /** + * + * @return all of the results to send + */ + public List getAnomalyResults() { + return anomalyResults; + } + + /** + * Add result to send + * @param resultWriteRequest The result write request + */ + public void add(ResultWriteRequest resultWriteRequest) { + anomalyResults.add(resultWriteRequest); + } + + /** + * + * @return total index requests + */ + public int numberOfActions() { + return anomalyResults.size(); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java b/src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java index 476f42482..8206d908e 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java @@ -18,8 +18,8 @@ import org.opensearch.action.ActionResponse; import org.opensearch.action.index.IndexRequest; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; public class ADResultBulkResponse extends ActionResponse { public static final String RETRY_REQUESTS_JSON_KEY = "retry_requests"; diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java-e b/src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java-e new file mode 100644 index 000000000..8206d908e --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java-e @@ -0,0 +1,70 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import org.opensearch.action.ActionResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class ADResultBulkResponse extends ActionResponse { + public static final String RETRY_REQUESTS_JSON_KEY = "retry_requests"; + + private List retryRequests; + + /** + * + * @param retryRequests a list of requests to retry + */ + public ADResultBulkResponse(List retryRequests) { + this.retryRequests = retryRequests; + } + + public ADResultBulkResponse() { + this.retryRequests = null; + } + + public ADResultBulkResponse(StreamInput in) throws IOException { + int size = in.readInt(); + if (size > 0) { + retryRequests = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + retryRequests.add(new IndexRequest(in)); + } + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (retryRequests == null || retryRequests.size() == 0) { + out.writeInt(0); + } else { + out.writeInt(retryRequests.size()); + for (IndexRequest result : retryRequests) { + result.writeTo(out); + } + } + } + + public boolean hasFailures() { + return retryRequests != null && retryRequests.size() > 0; + } + + public Optional> getRetryRequests() { + return Optional.ofNullable(retryRequests); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java-e new file mode 100644 index 000000000..bf9387760 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java-e @@ -0,0 +1,140 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_INDEX_PRESSURE_HARD_LIMIT; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_INDEX_PRESSURE_SOFT_LIMIT; +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.opensearch.core.index.IndexingPressure.MAX_INDEXING_BYTES; + +import java.io.IOException; +import java.util.List; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.ratelimit.ResultWriteRequest; +import org.opensearch.ad.util.BulkUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.index.IndexingPressure; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +public class ADResultBulkTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(ADResultBulkTransportAction.class); + private IndexingPressure indexingPressure; + private final long primaryAndCoordinatingLimits; + private float softLimit; + private float hardLimit; + private String indexName; + private Client client; + private Random random; + + @Inject + public ADResultBulkTransportAction( + TransportService transportService, + ActionFilters actionFilters, + IndexingPressure indexingPressure, + Settings settings, + ClusterService clusterService, + Client client + ) { + super(ADResultBulkAction.NAME, transportService, actionFilters, ADResultBulkRequest::new, ThreadPool.Names.SAME); + this.indexingPressure = indexingPressure; + this.primaryAndCoordinatingLimits = MAX_INDEXING_BYTES.get(settings).getBytes(); + this.softLimit = AD_INDEX_PRESSURE_SOFT_LIMIT.get(settings); + this.hardLimit = AD_INDEX_PRESSURE_HARD_LIMIT.get(settings); + this.indexName = ADCommonName.ANOMALY_RESULT_INDEX_ALIAS; + this.client = client; + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_INDEX_PRESSURE_SOFT_LIMIT, it -> softLimit = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_INDEX_PRESSURE_HARD_LIMIT, it -> hardLimit = it); + // random seed is 42. Can be any number + this.random = new Random(42); + } + + @Override + protected void doExecute(Task task, ADResultBulkRequest request, ActionListener listener) { + // Concurrent indexing memory limit = 10% of heap + // indexing pressure = indexing bytes / indexing limit + // Write all until index pressure (global indexing memory pressure) is less than 80% of 10% of heap. Otherwise, index + // all non-zero anomaly grade index requests and index zero anomaly grade index requests with probability (1 - index pressure). 
+ long totalBytes = indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes() + indexingPressure.getCurrentReplicaBytes(); + float indexingPressurePercent = (float) totalBytes / primaryAndCoordinatingLimits; + List results = request.getAnomalyResults(); + + if (results == null || results.size() < 1) { + listener.onResponse(new ADResultBulkResponse()); + } + + BulkRequest bulkRequest = new BulkRequest(); + + if (indexingPressurePercent <= softLimit) { + for (ResultWriteRequest resultWriteRequest : results) { + addResult(bulkRequest, resultWriteRequest.getResult(), resultWriteRequest.getCustomResultIndex()); + } + } else if (indexingPressurePercent <= hardLimit) { + // exceed soft limit (60%) but smaller than hard limit (90%) + float acceptProbability = 1 - indexingPressurePercent; + for (ResultWriteRequest resultWriteRequest : results) { + AnomalyResult result = resultWriteRequest.getResult(); + if (result.isHighPriority() || random.nextFloat() < acceptProbability) { + addResult(bulkRequest, result, resultWriteRequest.getCustomResultIndex()); + } + } + } else { + // if exceeding hard limit, only index non-zero grade or error result + for (ResultWriteRequest resultWriteRequest : results) { + AnomalyResult result = resultWriteRequest.getResult(); + if (result.isHighPriority()) { + addResult(bulkRequest, result, resultWriteRequest.getCustomResultIndex()); + } + } + } + + if (bulkRequest.numberOfActions() > 0) { + client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(bulkResponse -> { + List failedRequests = BulkUtil.getFailedIndexRequest(bulkRequest, bulkResponse); + listener.onResponse(new ADResultBulkResponse(failedRequests)); + }, e -> { + LOG.error("Failed to bulk index AD result", e); + listener.onFailure(e); + })); + } else { + listener.onResponse(new ADResultBulkResponse()); + } + } + + private void addResult(BulkRequest bulkRequest, AnomalyResult result, String resultIndex) { + String index = resultIndex == null ? indexName : resultIndex; + try (XContentBuilder builder = jsonBuilder()) { + IndexRequest indexRequest = new IndexRequest(index).source(result.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)); + bulkRequest.add(indexRequest); + } catch (IOException e) { + LOG.error("Failed to prepare bulk index request for index " + index, e); + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodeRequest.java b/src/main/java/org/opensearch/ad/transport/ADStatsNodeRequest.java index a018cf87b..099bc7db1 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsNodeRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ADStatsNodeRequest.java @@ -13,8 +13,8 @@ import java.io.IOException; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.transport.TransportRequest; /** diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodeRequest.java-e b/src/main/java/org/opensearch/ad/transport/ADStatsNodeRequest.java-e new file mode 100644 index 000000000..099bc7db1 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADStatsNodeRequest.java-e @@ -0,0 +1,61 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. 
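The admission logic above throttles result writes by global indexing memory pressure: below the soft limit everything is indexed, between the soft and hard limits high-priority results (non-zero anomaly grade or errors) are always indexed while the rest are accepted with probability (1 - pressure), and above the hard limit only high-priority results are kept. Below is a self-contained sketch of that decision; the threshold constants are illustrative stand-ins for the AD_INDEX_PRESSURE_SOFT_LIMIT and AD_INDEX_PRESSURE_HARD_LIMIT cluster settings.

// Illustrative admission decision for one result at a given indexing-pressure level.
import java.util.Random;

public class IndexPressureAdmissionDemo {
    private static final float SOFT_LIMIT = 0.6f;   // below this, index everything
    private static final float HARD_LIMIT = 0.9f;   // above this, high-priority only
    private static final Random RANDOM = new Random(42);

    static boolean accept(float indexingPressurePercent, boolean highPriority) {
        if (indexingPressurePercent <= SOFT_LIMIT) {
            return true;                                     // low pressure: keep all results
        }
        if (indexingPressurePercent <= HARD_LIMIT) {
            // between the limits: always keep high-priority results,
            // keep the rest with probability (1 - pressure)
            return highPriority || RANDOM.nextFloat() < 1 - indexingPressurePercent;
        }
        return highPriority;                                 // over the hard limit
    }

    public static void main(String[] args) {
        System.out.println(accept(0.5f, false));   // true
        System.out.println(accept(0.75f, false));  // kept roughly 25% of the time
        System.out.println(accept(0.95f, false));  // false
        System.out.println(accept(0.95f, true));   // true
    }
}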
See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +/** + * ADStatsNodeRequest to get a nodes stat + */ +public class ADStatsNodeRequest extends TransportRequest { + private ADStatsRequest request; + + /** + * Constructor + */ + public ADStatsNodeRequest() { + super(); + } + + public ADStatsNodeRequest(StreamInput in) throws IOException { + super(in); + this.request = new ADStatsRequest(in); + } + + /** + * Constructor + * + * @param request ADStatsRequest + */ + public ADStatsNodeRequest(ADStatsRequest request) { + this.request = request; + } + + /** + * Get ADStatsRequest + * + * @return ADStatsRequest for this node + */ + public ADStatsRequest getADStatsRequest() { + return request; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + request.writeTo(out); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java b/src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java index b2fb7d595..f5296cf17 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java @@ -16,8 +16,8 @@ import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java-e b/src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java-e new file mode 100644 index 000000000..f5296cf17 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADStatsNodeResponse.java-e @@ -0,0 +1,96 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; + +/** + * ADStatsNodeResponse + */ +public class ADStatsNodeResponse extends BaseNodeResponse implements ToXContentFragment { + + private Map statsMap; + + /** + * Constructor + * + * @param in StreamInput + * @throws IOException throws an IO exception if the StreamInput cannot be read from + */ + public ADStatsNodeResponse(StreamInput in) throws IOException { + super(in); + this.statsMap = in.readMap(StreamInput::readString, StreamInput::readGenericValue); + } + + /** + * Constructor + * + * @param node node + * @param statsToValues Mapping of stat name to value + */ + public ADStatsNodeResponse(DiscoveryNode node, Map statsToValues) { + super(node); + this.statsMap = statsToValues; + } + + /** + * Creates a new ADStatsNodeResponse object and reads in the stats from an input stream + * + * @param in StreamInput to read from + * @return ADStatsNodeResponse object corresponding to the input stream + * @throws IOException throws an IO exception if the StreamInput cannot be read from + */ + public static ADStatsNodeResponse readStats(StreamInput in) throws IOException { + + return new ADStatsNodeResponse(in); + } + + /** + * getStatsMap + * + * @return map of stats + */ + public Map getStatsMap() { + return statsMap; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeMap(statsMap, StreamOutput::writeString, StreamOutput::writeGenericValue); + } + + /** + * Converts statsMap to xContent + * + * @param builder XContentBuilder + * @param params Params + * @return XContentBuilder + * @throws IOException thrown by builder for invalid field + */ + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + for (String stat : statsMap.keySet()) { + builder.field(stat, statsMap.get(stat)); + } + + return builder; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodesAction.java-e b/src/main/java/org/opensearch/ad/transport/ADStatsNodesAction.java-e new file mode 100644 index 000000000..f6f39ab85 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADStatsNodesAction.java-e @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +/** + * ADStatsNodesAction class + */ +public class ADStatsNodesAction extends ActionType { + + // Internal Action which is not used for public facing RestAPIs. 
+ public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "stats/nodes"; + public static final ADStatsNodesAction INSTANCE = new ADStatsNodesAction(); + + /** + * Constructor + */ + private ADStatsNodesAction() { + super(NAME, ADStatsNodesResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodesResponse.java b/src/main/java/org/opensearch/ad/transport/ADStatsNodesResponse.java index a814f2c93..2dbdff03c 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsNodesResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ADStatsNodesResponse.java @@ -18,8 +18,8 @@ import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodesResponse.java-e b/src/main/java/org/opensearch/ad/transport/ADStatsNodesResponse.java-e new file mode 100644 index 000000000..2dbdff03c --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADStatsNodesResponse.java-e @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +/** + * ADStatsNodesResponse consists of the aggregated responses from the nodes + */ +public class ADStatsNodesResponse extends BaseNodesResponse implements ToXContentObject { + + private static final String NODES_KEY = "nodes"; + + /** + * Constructor + * + * @param in StreamInput + * @throws IOException thrown when unable to read from stream + */ + public ADStatsNodesResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(ADStatsNodeResponse::readStats), in.readList(FailedNodeException::new)); + } + + /** + * Constructor + * + * @param clusterName name of cluster + * @param nodes List of ADStatsNodeResponses from nodes + * @param failures List of failures from nodes + */ + public ADStatsNodesResponse(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + public void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + @Override + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(ADStatsNodeResponse::readStats); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) 
throws IOException { + String nodeId; + DiscoveryNode node; + builder.startObject(NODES_KEY); + for (ADStatsNodeResponse adStats : getNodes()) { + node = adStats.getNode(); + nodeId = node.getId(); + builder.startObject(nodeId); + adStats.toXContent(builder, params); + builder.endObject(); + } + builder.endObject(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java index 66e7ae0cb..17a81da0a 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java @@ -25,7 +25,7 @@ import org.opensearch.ad.task.ADTaskManager; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.monitor.jvm.JvmService; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java-e new file mode 100644 index 000000000..17a81da0a --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java-e @@ -0,0 +1,131 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.stats.InternalStatNames; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.monitor.jvm.JvmService; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +/** + * ADStatsNodesTransportAction contains the logic to extract the stats from the nodes + */ +public class ADStatsNodesTransportAction extends + TransportNodesAction { + + private ADStats adStats; + private final JvmService jvmService; + private final ADTaskManager adTaskManager; + + /** + * Constructor + * + * @param threadPool ThreadPool to use + * @param clusterService ClusterService + * @param transportService TransportService + * @param actionFilters Action Filters + * @param adStats ADStats object + * @param jvmService ES JVM Service + * @param adTaskManager AD task manager + */ + @Inject + public ADStatsNodesTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + ADStats adStats, + JvmService jvmService, + ADTaskManager adTaskManager + ) { + super( + ADStatsNodesAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + ADStatsRequest::new, + ADStatsNodeRequest::new, + 
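ADStatsNodesResponse#toXContent above nests each node's flat stat map under its node id inside a top-level "nodes" object, so the aggregated stats body is roughly of the following shape (node ids, stat names, and values here are hypothetical placeholders, not taken from this diff):

{
  "nodes": {
    "node-id-1": {
      "jvm_heap_usage": 37,
      "ad_used_batch_task_slot_count": 2
    },
    "node-id-2": {
      "jvm_heap_usage": 52,
      "ad_used_batch_task_slot_count": 0
    }
  }
}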
ThreadPool.Names.MANAGEMENT, + ADStatsNodeResponse.class + ); + this.adStats = adStats; + this.jvmService = jvmService; + this.adTaskManager = adTaskManager; + } + + @Override + protected ADStatsNodesResponse newResponse( + ADStatsRequest request, + List responses, + List failures + ) { + return new ADStatsNodesResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected ADStatsNodeRequest newNodeRequest(ADStatsRequest request) { + return new ADStatsNodeRequest(request); + } + + @Override + protected ADStatsNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new ADStatsNodeResponse(in); + } + + @Override + protected ADStatsNodeResponse nodeOperation(ADStatsNodeRequest request) { + return createADStatsNodeResponse(request.getADStatsRequest()); + } + + private ADStatsNodeResponse createADStatsNodeResponse(ADStatsRequest adStatsRequest) { + Map statValues = new HashMap<>(); + Set statsToBeRetrieved = adStatsRequest.getStatsToBeRetrieved(); + + if (statsToBeRetrieved.contains(InternalStatNames.JVM_HEAP_USAGE.getName())) { + long heapUsedPercent = jvmService.stats().getMem().getHeapUsedPercent(); + statValues.put(InternalStatNames.JVM_HEAP_USAGE.getName(), heapUsedPercent); + } + + if (statsToBeRetrieved.contains(InternalStatNames.AD_USED_BATCH_TASK_SLOT_COUNT.getName())) { + int usedTaskSlot = adTaskManager.getLocalAdUsedBatchTaskSlot(); + statValues.put(InternalStatNames.AD_USED_BATCH_TASK_SLOT_COUNT.getName(), usedTaskSlot); + } + + if (statsToBeRetrieved.contains(InternalStatNames.AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT.getName())) { + int assignedBatchTaskSlot = adTaskManager.getLocalAdAssignedBatchTaskSlot(); + statValues.put(InternalStatNames.AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT.getName(), assignedBatchTaskSlot); + } + + for (String statName : adStats.getNodeStats().keySet()) { + if (statsToBeRetrieved.contains(statName)) { + statValues.put(statName, adStats.getStats().get(statName).getValue()); + } + } + + return new ADStatsNodeResponse(clusterService.localNode(), statValues); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsRequest.java b/src/main/java/org/opensearch/ad/transport/ADStatsRequest.java index 204553089..32301e526 100644 --- a/src/main/java/org/opensearch/ad/transport/ADStatsRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ADStatsRequest.java @@ -17,8 +17,8 @@ import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; /** * ADStatsRequest implements a request to obtain stats about the AD plugin diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsRequest.java-e b/src/main/java/org/opensearch/ad/transport/ADStatsRequest.java-e new file mode 100644 index 000000000..32301e526 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADStatsRequest.java-e @@ -0,0 +1,103 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; + +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +/** + * ADStatsRequest implements a request to obtain stats about the AD plugin + */ +public class ADStatsRequest extends BaseNodesRequest { + + /** + * Key indicating all stats should be retrieved + */ + public static final String ALL_STATS_KEY = "_all"; + + private Set statsToBeRetrieved; + + public ADStatsRequest(StreamInput in) throws IOException { + super(in); + statsToBeRetrieved = in.readSet(StreamInput::readString); + } + + /** + * Constructor + * + * @param nodeIds nodeIds of nodes' stats to be retrieved + */ + public ADStatsRequest(String... nodeIds) { + super(nodeIds); + statsToBeRetrieved = new HashSet<>(); + } + + /** + * Constructor + * + * @param nodes nodes of nodes' stats to be retrieved + */ + public ADStatsRequest(DiscoveryNode... nodes) { + super(nodes); + statsToBeRetrieved = new HashSet<>(); + } + + /** + * Adds a stat to the set of stats to be retrieved + * + * @param stat name of the stat + */ + public void addStat(String stat) { + statsToBeRetrieved.add(stat); + } + + /** + * Add all stats to be retrieved + * + * @param statsToBeAdded set of stats to be retrieved + */ + public void addAll(Set statsToBeAdded) { + statsToBeRetrieved.addAll(statsToBeAdded); + } + + /** + * Remove all stats from retrieval set + */ + public void clear() { + statsToBeRetrieved.clear(); + } + + /** + * Get the set that tracks which stats should be retrieved + * + * @return the set that contains the stat names marked for retrieval + */ + public Set getStatsToBeRetrieved() { + return statsToBeRetrieved; + } + + public void readFrom(StreamInput in) throws IOException { + statsToBeRetrieved = in.readSet(StreamInput::readString); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeStringCollection(statsToBeRetrieved); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileAction.java-e b/src/main/java/org/opensearch/ad/transport/ADTaskProfileAction.java-e new file mode 100644 index 000000000..f2b198d1c --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileAction.java-e @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
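createADStatsNodeResponse above returns only the stats whose names appear in ADStatsRequest#getStatsToBeRetrieved, mixing node-level internals (JVM heap usage, batch task slots) with the plugin's own ADStats counters. A plain-Java sketch of that name-based filtering follows; the stat names and values are made up for illustration.

// Keep only the node-local stats that were explicitly requested.
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

public class StatsFilterDemo {
    static Map<String, Object> filter(Set<String> requested, Map<String, Object> nodeLocalStats) {
        Map<String, Object> selected = new HashMap<>();
        for (Map.Entry<String, Object> stat : nodeLocalStats.entrySet()) {
            if (requested.contains(stat.getKey())) {
                selected.put(stat.getKey(), stat.getValue());
            }
        }
        return selected;
    }

    public static void main(String[] args) {
        Map<String, Object> local = Map.of("jvm_heap_usage", 37L, "ad_execute_request_count", 12L);
        System.out.println(filter(Set.of("jvm_heap_usage"), local));  // {jvm_heap_usage=37}
    }
}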
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonName.AD_TASK; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class ADTaskProfileAction extends ActionType { + + public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile/" + AD_TASK; + public static final ADTaskProfileAction INSTANCE = new ADTaskProfileAction(); + + private ADTaskProfileAction() { + super(NAME, ADTaskProfileResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileNodeRequest.java b/src/main/java/org/opensearch/ad/transport/ADTaskProfileNodeRequest.java index 589a13520..a8fab2b87 100644 --- a/src/main/java/org/opensearch/ad/transport/ADTaskProfileNodeRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileNodeRequest.java @@ -13,8 +13,8 @@ import java.io.IOException; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.transport.TransportRequest; public class ADTaskProfileNodeRequest extends TransportRequest { diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileNodeRequest.java-e b/src/main/java/org/opensearch/ad/transport/ADTaskProfileNodeRequest.java-e new file mode 100644 index 000000000..a8fab2b87 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileNodeRequest.java-e @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +public class ADTaskProfileNodeRequest extends TransportRequest { + private String detectorId; + + public ADTaskProfileNodeRequest(StreamInput in) throws IOException { + super(in); + this.detectorId = in.readString(); + } + + public ADTaskProfileNodeRequest(ADTaskProfileRequest request) { + this.detectorId = request.getId(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(detectorId); + } + + public String getId() { + return detectorId; + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileNodeResponse.java b/src/main/java/org/opensearch/ad/transport/ADTaskProfileNodeResponse.java index 5b74cc994..363e70be2 100644 --- a/src/main/java/org/opensearch/ad/transport/ADTaskProfileNodeResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileNodeResponse.java @@ -19,8 +19,8 @@ import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; public class ADTaskProfileNodeResponse extends BaseNodeResponse { private static final Logger logger = LogManager.getLogger(ADTaskProfileNodeResponse.class); diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileNodeResponse.java-e b/src/main/java/org/opensearch/ad/transport/ADTaskProfileNodeResponse.java-e new file mode 100644 index 000000000..363e70be2 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileNodeResponse.java-e @@ -0,0 +1,63 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.Version; +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.ad.model.ADTaskProfile; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class ADTaskProfileNodeResponse extends BaseNodeResponse { + private static final Logger logger = LogManager.getLogger(ADTaskProfileNodeResponse.class); + private ADTaskProfile adTaskProfile; + private Version remoteAdVersion; + + public ADTaskProfileNodeResponse(DiscoveryNode node, ADTaskProfile adTaskProfile, Version remoteAdVersion) { + super(node); + this.adTaskProfile = adTaskProfile; + this.remoteAdVersion = remoteAdVersion; + } + + public ADTaskProfileNodeResponse(StreamInput in) throws IOException { + super(in); + if (in.readBoolean()) { + adTaskProfile = new ADTaskProfile(in); + } else { + adTaskProfile = null; + } + } + + public ADTaskProfile getAdTaskProfile() { + return adTaskProfile; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + if (adTaskProfile != null && (remoteAdVersion != null || adTaskProfile.getNodeId() != null)) { + out.writeBoolean(true); + adTaskProfile.writeTo(out, remoteAdVersion); + } else { + out.writeBoolean(false); + } + } + + public static ADTaskProfileNodeResponse readNodeResponse(StreamInput in) throws IOException { + return new ADTaskProfileNodeResponse(in); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileRequest.java b/src/main/java/org/opensearch/ad/transport/ADTaskProfileRequest.java index 91bfa308e..7b078d215 100644 --- a/src/main/java/org/opensearch/ad/transport/ADTaskProfileRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileRequest.java @@ -19,9 +19,9 @@ import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; public class ADTaskProfileRequest extends BaseNodesRequest { diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileRequest.java-e b/src/main/java/org/opensearch/ad/transport/ADTaskProfileRequest.java-e new file mode 100644 index 000000000..62d70512c --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileRequest.java-e @@ -0,0 +1,58 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
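ADTaskProfileNodeResponse#writeTo above guards the profile with a presence flag and only serializes it when the remote AD version is known (or the profile carries a node id), which keeps mixed-version clusters from receiving a payload they cannot parse. The following is a minimal plain-Java analog of the boolean-flag-then-optional-payload pattern, again using java.io streams purely for illustration.

// Write a presence flag first; write the payload only when it exists.
import java.io.*;

public class OptionalPayloadDemo {
    static byte[] write(String profileOrNull) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            if (profileOrNull != null) {
                out.writeBoolean(true);
                out.writeUTF(profileOrNull);   // stands in for adTaskProfile.writeTo(out, remoteAdVersion)
            } else {
                out.writeBoolean(false);
            }
        }
        return bytes.toByteArray();
    }

    static String read(byte[] payload) throws IOException {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(payload))) {
            return in.readBoolean() ? in.readUTF() : null;
        }
    }

    public static void main(String[] args) throws IOException {
        System.out.println(read(write("running: 42% initialized")));  // payload present
        System.out.println(read(write(null)));                        // null round-trips safely
    }
}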
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.Strings; + +public class ADTaskProfileRequest extends BaseNodesRequest { + + private String detectorId; + + public ADTaskProfileRequest(StreamInput in) throws IOException { + super(in); + this.detectorId = in.readString(); + } + + public ADTaskProfileRequest(String detectorId, DiscoveryNode... nodes) { + super(nodes); + this.detectorId = detectorId; + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(detectorId)) { + validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + } + return validationException; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(detectorId); + } + + public String getId() { + return detectorId; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileResponse.java b/src/main/java/org/opensearch/ad/transport/ADTaskProfileResponse.java index 1c1335552..951f362f9 100644 --- a/src/main/java/org/opensearch/ad/transport/ADTaskProfileResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileResponse.java @@ -17,8 +17,8 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; public class ADTaskProfileResponse extends BaseNodesResponse { diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileResponse.java-e b/src/main/java/org/opensearch/ad/transport/ADTaskProfileResponse.java-e new file mode 100644 index 000000000..951f362f9 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileResponse.java-e @@ -0,0 +1,43 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class ADTaskProfileResponse extends BaseNodesResponse { + + public ADTaskProfileResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(ADTaskProfileNodeResponse::readNodeResponse), in.readList(FailedNodeException::new)); + } + + public ADTaskProfileResponse(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + } + + @Override + public void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + @Override + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(ADTaskProfileNodeResponse::readNodeResponse); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java index 393973d7b..6902d6de8 100644 --- a/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java @@ -23,7 +23,7 @@ import org.opensearch.ad.task.ADTaskManager; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java-e new file mode 100644 index 000000000..6902d6de8 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java-e @@ -0,0 +1,86 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.Version; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.model.ADTaskProfile; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class ADTaskProfileTransportAction extends + TransportNodesAction { + + private ADTaskManager adTaskManager; + private HashRing hashRing; + + @Inject + public ADTaskProfileTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + ADTaskManager adTaskManager, + HashRing hashRing + ) { + super( + ADTaskProfileAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + ADTaskProfileRequest::new, + ADTaskProfileNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + ADTaskProfileNodeResponse.class + ); + this.adTaskManager = adTaskManager; + this.hashRing = hashRing; + } + + @Override + protected ADTaskProfileResponse newResponse( + ADTaskProfileRequest request, + List responses, + List failures + ) { + return new ADTaskProfileResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected ADTaskProfileNodeRequest newNodeRequest(ADTaskProfileRequest request) { + return new ADTaskProfileNodeRequest(request); + } + + @Override + protected ADTaskProfileNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new ADTaskProfileNodeResponse(in); + } + + @Override + protected ADTaskProfileNodeResponse nodeOperation(ADTaskProfileNodeRequest request) { + String remoteNodeId = request.getParentTask().getNodeId(); + Version remoteAdVersion = hashRing.getAdVersion(remoteNodeId); + ADTaskProfile adTaskProfile = adTaskManager.getLocalADTaskProfilesByDetectorId(request.getId()); + return new ADTaskProfileNodeResponse(clusterService.localNode(), adTaskProfile, remoteAdVersion); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobAction.java-e b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobAction.java-e new file mode 100644 index 000000000..b11283181 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobAction.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class AnomalyDetectorJobAction extends ActionType { + // External Action which used for public facing RestAPIs. 
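[Editor's illustration] As a hedged sketch of how the ADTaskProfile fan-out above is typically consumed from plugin code: it assumes the usual INSTANCE singleton on ADTaskProfileAction and placeholder ids/nodes, so treat it as an illustration rather than code from this patch.

    // Hypothetical helper: collect AD task profiles from a caller-chosen set of nodes.
    // `client` and `LOG` come from the enclosing component; the detector id is a placeholder.
    static void logTaskProfiles(Client client, Logger LOG, DiscoveryNode... eligibleNodes) {
        ADTaskProfileRequest profileRequest = new ADTaskProfileRequest("detector-123", eligibleNodes);
        client.execute(ADTaskProfileAction.INSTANCE, profileRequest, ActionListener.wrap(response -> {
            for (ADTaskProfileNodeResponse nodeResponse : response.getNodes()) {
                LOG.info("node {} -> task profile {}", nodeResponse.getNode().getId(), nodeResponse.getAdTaskProfile());
            }
        }, e -> LOG.error("failed to collect AD task profiles", e)));
    }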
+ public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/jobmanagement"; + public static final AnomalyDetectorJobAction INSTANCE = new AnomalyDetectorJobAction(); + + private AnomalyDetectorJobAction() { + super(NAME, AnomalyDetectorJobResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobRequest.java b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobRequest.java index f25914365..3a62315a6 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobRequest.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobRequest.java @@ -15,8 +15,8 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.timeseries.model.DateRange; public class AnomalyDetectorJobRequest extends ActionRequest { diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobRequest.java-e b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobRequest.java-e new file mode 100644 index 000000000..3a62315a6 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobRequest.java-e @@ -0,0 +1,123 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.timeseries.model.DateRange; + +public class AnomalyDetectorJobRequest extends ActionRequest { + + private String detectorID; + private DateRange detectionDateRange; + private boolean historical; + private long seqNo; + private long primaryTerm; + private String rawPath; + + public AnomalyDetectorJobRequest(StreamInput in) throws IOException { + super(in); + detectorID = in.readString(); + seqNo = in.readLong(); + primaryTerm = in.readLong(); + rawPath = in.readString(); + if (in.readBoolean()) { + detectionDateRange = new DateRange(in); + } + historical = in.readBoolean(); + } + + public AnomalyDetectorJobRequest(String detectorID, long seqNo, long primaryTerm, String rawPath) { + this(detectorID, null, false, seqNo, primaryTerm, rawPath); + } + + /** + * Constructor function. + * + * The detectionDateRange and historical boolean can be passed in individually. + * The historical flag is for stopping detector, the detectionDateRange is for + * starting detector. It's ok if historical is true but detectionDateRange is + * null. 
+ * + * @param detectorID detector identifier + * @param detectionDateRange detection date range + * @param historical historical analysis or not + * @param seqNo seq no + * @param primaryTerm primary term + * @param rawPath raw request path + */ + public AnomalyDetectorJobRequest( + String detectorID, + DateRange detectionDateRange, + boolean historical, + long seqNo, + long primaryTerm, + String rawPath + ) { + super(); + this.detectorID = detectorID; + this.detectionDateRange = detectionDateRange; + this.historical = historical; + this.seqNo = seqNo; + this.primaryTerm = primaryTerm; + this.rawPath = rawPath; + } + + public String getDetectorID() { + return detectorID; + } + + public DateRange getDetectionDateRange() { + return detectionDateRange; + } + + public long getSeqNo() { + return seqNo; + } + + public long getPrimaryTerm() { + return primaryTerm; + } + + public String getRawPath() { + return rawPath; + } + + public boolean isHistorical() { + return historical; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(detectorID); + out.writeLong(seqNo); + out.writeLong(primaryTerm); + out.writeString(rawPath); + if (detectionDateRange != null) { + out.writeBoolean(true); + detectionDateRange.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeBoolean(historical); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobResponse.java b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobResponse.java index b2db75d5d..157d50000 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobResponse.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobResponse.java @@ -14,11 +14,11 @@ import java.io.IOException; import org.opensearch.action.ActionResponse; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.rest.RestStatus; import org.opensearch.timeseries.util.RestHandlerUtils; public class AnomalyDetectorJobResponse extends ActionResponse implements ToXContentObject { diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobResponse.java-e b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobResponse.java-e new file mode 100644 index 000000000..f5475fb2d --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobResponse.java-e @@ -0,0 +1,71 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.timeseries.util.RestHandlerUtils; + +public class AnomalyDetectorJobResponse extends ActionResponse implements ToXContentObject { + private final String id; + private final long version; + private final long seqNo; + private final long primaryTerm; + private final RestStatus restStatus; + + public AnomalyDetectorJobResponse(StreamInput in) throws IOException { + super(in); + id = in.readString(); + version = in.readLong(); + seqNo = in.readLong(); + primaryTerm = in.readLong(); + restStatus = in.readEnum(RestStatus.class); + } + + public AnomalyDetectorJobResponse(String id, long version, long seqNo, long primaryTerm, RestStatus restStatus) { + this.id = id; + this.version = version; + this.seqNo = seqNo; + this.primaryTerm = primaryTerm; + this.restStatus = restStatus; + } + + public String getId() { + return id; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + out.writeLong(version); + out.writeLong(seqNo); + out.writeLong(primaryTerm); + out.writeEnum(restStatus); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder + .startObject() + .field(RestHandlerUtils._ID, id) + .field(RestHandlerUtils._VERSION, version) + .field(RestHandlerUtils._SEQ_NO, seqNo) + .field(RestHandlerUtils._PRIMARY_TERM, primaryTerm) + .endObject(); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java-e new file mode 100644 index 000000000..1f86cefbb --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java-e @@ -0,0 +1,154 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_START_DETECTOR; +import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_STOP_DETECTOR; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; +import static org.opensearch.timeseries.util.ParseUtils.getUserContext; +import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.ExecuteADResultResponseRecorder; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +public class AnomalyDetectorJobTransportAction extends HandledTransportAction { + private final Logger logger = LogManager.getLogger(AnomalyDetectorJobTransportAction.class); + + private final Client client; + private final ClusterService clusterService; + private final Settings settings; + private final ADIndexManagement anomalyDetectionIndices; + private final NamedXContentRegistry xContentRegistry; + private volatile Boolean filterByEnabled; + private final ADTaskManager adTaskManager; + private final TransportService transportService; + private final ExecuteADResultResponseRecorder recorder; + + @Inject + public AnomalyDetectorJobTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService, + Settings settings, + ADIndexManagement anomalyDetectionIndices, + NamedXContentRegistry xContentRegistry, + ADTaskManager adTaskManager, + ExecuteADResultResponseRecorder recorder + ) { + super(AnomalyDetectorJobAction.NAME, transportService, actionFilters, AnomalyDetectorJobRequest::new); + this.transportService = transportService; + this.client = client; + this.clusterService = clusterService; + this.settings = settings; + this.anomalyDetectionIndices = anomalyDetectionIndices; + this.xContentRegistry = xContentRegistry; + this.adTaskManager = adTaskManager; + filterByEnabled = FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + this.recorder = recorder; + } + + @Override + protected void doExecute(Task task, AnomalyDetectorJobRequest request, ActionListener actionListener) { + String detectorId = request.getDetectorID(); + DateRange detectionDateRange = request.getDetectionDateRange(); + boolean historical = request.isHistorical(); + long seqNo 
= request.getSeqNo(); + long primaryTerm = request.getPrimaryTerm(); + String rawPath = request.getRawPath(); + TimeValue requestTimeout = REQUEST_TIMEOUT.get(settings); + String errorMessage = rawPath.endsWith(RestHandlerUtils.START_JOB) ? FAIL_TO_START_DETECTOR : FAIL_TO_STOP_DETECTOR; + ActionListener listener = wrapRestActionListener(actionListener, errorMessage); + + // By the time request reaches here, the user permissions are validated by Security plugin. + User user = getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute( + user, + detectorId, + filterByEnabled, + listener, + (anomalyDetector) -> executeDetector( + listener, + detectorId, + detectionDateRange, + historical, + seqNo, + primaryTerm, + rawPath, + requestTimeout, + user, + context + ), + client, + clusterService, + xContentRegistry + ); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + private void executeDetector( + ActionListener listener, + String detectorId, + DateRange detectionDateRange, + boolean historical, + long seqNo, + long primaryTerm, + String rawPath, + TimeValue requestTimeout, + User user, + ThreadContext.StoredContext context + ) { + IndexAnomalyDetectorJobActionHandler handler = new IndexAnomalyDetectorJobActionHandler( + client, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + requestTimeout, + xContentRegistry, + transportService, + adTaskManager, + recorder + ); + if (rawPath.endsWith(RestHandlerUtils.START_JOB)) { + adTaskManager.startDetector(detectorId, detectionDateRange, handler, user, transportService, context, listener); + } else if (rawPath.endsWith(RestHandlerUtils.STOP_JOB)) { + adTaskManager.stopDetector(detectorId, historical, handler, user, transportService, listener); + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultAction.java-e b/src/main/java/org/opensearch/ad/transport/AnomalyResultAction.java-e new file mode 100644 index 000000000..d61bd5822 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultAction.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class AnomalyResultAction extends ActionType { + // External Action which used for public facing RestAPIs. 
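[Editor's illustration] To make the AnomalyDetectorJobRequest constructor semantics described earlier concrete (a start call may carry a detection date range, a stop call only needs the historical flag), here is a hedged sketch that uses only the constructors shown above; the detector id, sequence numbers, and path prefixes are placeholders, not values from this patch.

    // Real-time start: 4-arg constructor, no detection date range, historical defaults to false.
    AnomalyDetectorJobRequest startRealtime = new AnomalyDetectorJobRequest(
        "detector-123", 0L, 1L, "/detectors/detector-123/" + RestHandlerUtils.START_JOB);

    // Stop a historical run: 6-arg constructor with historical=true and no date range.
    AnomalyDetectorJobRequest stopHistorical = new AnomalyDetectorJobRequest(
        "detector-123", null, true, 0L, 1L, "/detectors/detector-123/" + RestHandlerUtils.STOP_JOB);

    // The transport action above then routes on the raw-path suffix (START_JOB vs. STOP_JOB).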
+ public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/run"; + public static final AnomalyResultAction INSTANCE = new AnomalyResultAction(); + + private AnomalyResultAction() { + super(NAME, AnomalyResultResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java index b9efc10b3..e6f788aeb 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java @@ -22,11 +22,11 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.common.io.stream.InputStreamStreamInput; -import org.opensearch.common.io.stream.OutputStreamStreamOutput; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.constant.CommonMessages; diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java-e b/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java-e new file mode 100644 index 000000000..0108f5a12 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java-e @@ -0,0 +1,114 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; + +public class AnomalyResultRequest extends ActionRequest implements ToXContentObject { + private String adID; + // time range start and end. 
Unit: epoch milliseconds + private long start; + private long end; + + public AnomalyResultRequest(StreamInput in) throws IOException { + super(in); + adID = in.readString(); + start = in.readLong(); + end = in.readLong(); + } + + public AnomalyResultRequest(String adID, long start, long end) { + super(); + this.adID = adID; + this.start = start; + this.end = end; + } + + public long getStart() { + return start; + } + + public long getEnd() { + return end; + } + + public String getAdID() { + return adID; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(adID); + out.writeLong(start); + out.writeLong(end); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(adID)) { + validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + } + if (start <= 0 || end <= 0 || start > end) { + validationException = addValidationError( + String.format(Locale.ROOT, "%s: start %d, end %d", CommonMessages.INVALID_TIMESTAMP_ERR_MSG, start, end), + validationException + ); + } + return validationException; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.field(CommonName.START_JSON_KEY, start); + builder.field(CommonName.END_JSON_KEY, end); + builder.endObject(); + return builder; + } + + public static AnomalyResultRequest fromActionRequest(final ActionRequest actionRequest) { + if (actionRequest instanceof AnomalyResultRequest) { + return (AnomalyResultRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new AnomalyResultRequest(input); + } + } catch (IOException e) { + throw new IllegalArgumentException("failed to parse ActionRequest into AnomalyResultRequest", e); + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java index fb23a7c40..6f65fdb6d 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java @@ -22,11 +22,11 @@ import org.opensearch.action.ActionResponse; import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.common.io.stream.InputStreamStreamInput; -import org.opensearch.common.io.stream.OutputStreamStreamOutput; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.model.FeatureData; diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java-e b/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java-e new file mode 100644 index 000000000..5815f5430 
--- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java-e @@ -0,0 +1,376 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import org.opensearch.action.ActionResponse; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.model.FeatureData; + +public class AnomalyResultResponse extends ActionResponse implements ToXContentObject { + public static final String ANOMALY_GRADE_JSON_KEY = "anomalyGrade"; + public static final String CONFIDENCE_JSON_KEY = "confidence"; + public static final String ANOMALY_SCORE_JSON_KEY = "anomalyScore"; + public static final String ERROR_JSON_KEY = "error"; + public static final String FEATURES_JSON_KEY = "features"; + public static final String FEATURE_VALUE_JSON_KEY = "value"; + public static final String RCF_TOTAL_UPDATES_JSON_KEY = "rcfTotalUpdates"; + public static final String DETECTOR_INTERVAL_IN_MINUTES_JSON_KEY = "detectorIntervalInMinutes"; + public static final String RELATIVE_INDEX_FIELD_JSON_KEY = "relativeIndex"; + public static final String RELEVANT_ATTRIBUTION_FIELD_JSON_KEY = "relevantAttribution"; + public static final String PAST_VALUES_FIELD_JSON_KEY = "pastValues"; + public static final String EXPECTED_VAL_LIST_FIELD_JSON_KEY = "expectedValuesList"; + public static final String LIKELIHOOD_FIELD_JSON_KEY = "likelihoodOfValues"; + public static final String THRESHOLD_FIELD_JSON_KEY = "threshold"; + + private Double anomalyGrade; + private Double confidence; + private Double anomalyScore; + private String error; + private List features; + private Long rcfTotalUpdates; + private Long detectorIntervalInMinutes; + private Boolean isHCDetector; + private Integer relativeIndex; + private double[] relevantAttribution; + private double[] pastValues; + private double[][] expectedValuesList; + private double[] likelihoodOfValues; + private Double threshold; + + // used when returning an error/exception or empty result + public AnomalyResultResponse( + List features, + String error, + Long rcfTotalUpdates, + Long detectorIntervalInMinutes, + Boolean isHCDetector + ) { + this( + Double.NaN, + Double.NaN, + Double.NaN, + features, + error, + rcfTotalUpdates, + detectorIntervalInMinutes, + isHCDetector, + null, + null, + null, + null, + null, + Double.NaN + ); + } + + public AnomalyResultResponse( + Double anomalyGrade, + Double confidence, + Double anomalyScore, + List features, + String error, + Long rcfTotalUpdates, + Long detectorIntervalInMinutes, + Boolean isHCDetector, + Integer relativeIndex, + double[] currentTimeAttribution, + double[] pastValues, + 
double[][] expectedValuesList, + double[] likelihoodOfValues, + Double threshold + ) { + this.anomalyGrade = anomalyGrade; + this.confidence = confidence; + this.anomalyScore = anomalyScore; + this.features = features; + this.error = error; + this.rcfTotalUpdates = rcfTotalUpdates; + this.detectorIntervalInMinutes = detectorIntervalInMinutes; + this.isHCDetector = isHCDetector; + this.relativeIndex = relativeIndex; + this.relevantAttribution = currentTimeAttribution; + this.pastValues = pastValues; + this.expectedValuesList = expectedValuesList; + this.likelihoodOfValues = likelihoodOfValues; + this.threshold = threshold; + } + + public AnomalyResultResponse(StreamInput in) throws IOException { + super(in); + anomalyGrade = in.readDouble(); + confidence = in.readDouble(); + anomalyScore = in.readDouble(); + int size = in.readVInt(); + features = new ArrayList(); + for (int i = 0; i < size; i++) { + features.add(new FeatureData(in)); + } + error = in.readOptionalString(); + // new field added since AD 1.1 + // Only send AnomalyResultRequest to local node, no need to change this part for BWC + rcfTotalUpdates = in.readOptionalLong(); + detectorIntervalInMinutes = in.readOptionalLong(); + isHCDetector = in.readOptionalBoolean(); + + this.relativeIndex = in.readOptionalInt(); + + // input.readOptionalArray(i -> i.readDouble(), double[]::new) results in + // compiler error as readOptionalArray does not work for primitive array. + // use readDoubleArray and readBoolean instead + if (in.readBoolean()) { + this.relevantAttribution = in.readDoubleArray(); + } else { + this.relevantAttribution = null; + } + + if (in.readBoolean()) { + this.pastValues = in.readDoubleArray(); + } else { + this.pastValues = null; + } + + if (in.readBoolean()) { + int numberofExpectedVals = in.readVInt(); + this.expectedValuesList = new double[numberofExpectedVals][]; + for (int i = 0; i < numberofExpectedVals; i++) { + expectedValuesList[i] = in.readDoubleArray(); + } + } else { + this.expectedValuesList = null; + } + + if (in.readBoolean()) { + this.likelihoodOfValues = in.readDoubleArray(); + } else { + this.likelihoodOfValues = null; + } + + this.threshold = in.readOptionalDouble(); + } + + public double getAnomalyGrade() { + return anomalyGrade; + } + + public List getFeatures() { + return features; + } + + public double getConfidence() { + return confidence; + } + + public double getAnomalyScore() { + return anomalyScore; + } + + public String getError() { + return error; + } + + public Long getRcfTotalUpdates() { + return rcfTotalUpdates; + } + + public Long getIntervalInMinutes() { + return detectorIntervalInMinutes; + } + + public Boolean isHCDetector() { + return isHCDetector; + } + + public Integer getRelativeIndex() { + return relativeIndex; + } + + public double[] getCurrentTimeAttribution() { + return relevantAttribution; + } + + public double[] getOldValues() { + return pastValues; + } + + public double[][] getExpectedValuesList() { + return expectedValuesList; + } + + public double[] getLikelihoodOfValues() { + return likelihoodOfValues; + } + + public Double getThreshold() { + return threshold; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(anomalyGrade); + out.writeDouble(confidence); + out.writeDouble(anomalyScore); + out.writeVInt(features.size()); + for (FeatureData feature : features) { + feature.writeTo(out); + } + out.writeOptionalString(error); + out.writeOptionalLong(rcfTotalUpdates); + out.writeOptionalLong(detectorIntervalInMinutes); + 
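[Editor's illustration] The stream constructor above, and the writeTo method continuing below, spell out an explicit presence boolean because the optional-array stream helpers do not accept primitive double[] arrays. A minimal JDK-only sketch of the same presence-flag idea, using plain DataInput/DataOutput rather than the plugin's stream classes:

    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;

    final class OptionalDoubleArrayCodec {
        // Write a presence flag first; only write the payload when the array is non-null.
        static void write(DataOutputStream out, double[] values) throws IOException {
            out.writeBoolean(values != null);
            if (values != null) {
                out.writeInt(values.length);
                for (double v : values) {
                    out.writeDouble(v);
                }
            }
        }

        // Mirror the flag on the read side and return null when it says "absent".
        static double[] read(DataInputStream in) throws IOException {
            if (!in.readBoolean()) {
                return null;
            }
            double[] values = new double[in.readInt()];
            for (int i = 0; i < values.length; i++) {
                values[i] = in.readDouble();
            }
            return values;
        }
    }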
out.writeOptionalBoolean(isHCDetector); + + out.writeOptionalInt(relativeIndex); + + // writeOptionalArray does not work for primitive array. Use WriteDoubleArray + // instead. + if (relevantAttribution != null) { + out.writeBoolean(true); + out.writeDoubleArray(relevantAttribution); + } else { + out.writeBoolean(false); + } + + if (pastValues != null) { + out.writeBoolean(true); + out.writeDoubleArray(pastValues); + } else { + out.writeBoolean(false); + } + + if (expectedValuesList != null) { + out.writeBoolean(true); + int numberofExpectedVals = expectedValuesList.length; + out.writeVInt(expectedValuesList.length); + for (int i = 0; i < numberofExpectedVals; i++) { + out.writeDoubleArray(expectedValuesList[i]); + } + } else { + out.writeBoolean(false); + } + + if (likelihoodOfValues != null) { + out.writeBoolean(true); + out.writeDoubleArray(likelihoodOfValues); + } else { + out.writeBoolean(false); + } + + out.writeOptionalDouble(threshold); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ANOMALY_GRADE_JSON_KEY, anomalyGrade); + builder.field(CONFIDENCE_JSON_KEY, confidence); + builder.field(ANOMALY_SCORE_JSON_KEY, anomalyScore); + builder.field(ERROR_JSON_KEY, error); + builder.startArray(FEATURES_JSON_KEY); + for (FeatureData feature : features) { + feature.toXContent(builder, params); + } + builder.endArray(); + builder.field(RCF_TOTAL_UPDATES_JSON_KEY, rcfTotalUpdates); + builder.field(DETECTOR_INTERVAL_IN_MINUTES_JSON_KEY, detectorIntervalInMinutes); + builder.field(RELATIVE_INDEX_FIELD_JSON_KEY, relativeIndex); + builder.field(RELEVANT_ATTRIBUTION_FIELD_JSON_KEY, relevantAttribution); + builder.field(PAST_VALUES_FIELD_JSON_KEY, pastValues); + builder.field(EXPECTED_VAL_LIST_FIELD_JSON_KEY, expectedValuesList); + builder.field(LIKELIHOOD_FIELD_JSON_KEY, likelihoodOfValues); + builder.field(THRESHOLD_FIELD_JSON_KEY, threshold); + builder.endObject(); + return builder; + } + + public static AnomalyResultResponse fromActionResponse(final ActionResponse actionResponse) { + if (actionResponse instanceof AnomalyResultResponse) { + return (AnomalyResultResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (InputStreamStreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new AnomalyResultResponse(input); + } + } catch (IOException e) { + throw new IllegalArgumentException("failed to parse ActionResponse into AnomalyResultResponse", e); + } + } + + /** + * + * Convert AnomalyResultResponse to AnomalyResult + * + * @param detectorId Detector Id + * @param dataStartInstant data start time + * @param dataEndInstant data end time + * @param executionStartInstant execution start time + * @param executionEndInstant execution end time + * @param schemaVersion Schema version + * @param user Detector author + * @param error Error + * @return converted AnomalyResult + */ + public AnomalyResult toAnomalyResult( + String detectorId, + Instant dataStartInstant, + Instant dataEndInstant, + Instant executionStartInstant, + Instant executionEndInstant, + Integer schemaVersion, + User user, + String error + ) { + // Detector interval in milliseconds + long detectorIntervalMilli = Duration.between(dataStartInstant, dataEndInstant).toMillis(); + return AnomalyResult + .fromRawTRCFResult( + detectorId, + 
detectorIntervalMilli, + null, // real time results have no task id + anomalyScore, + anomalyGrade, + confidence, + features, + dataStartInstant, + dataEndInstant, + executionStartInstant, + executionEndInstant, + error, + Optional.empty(), + user, + schemaVersion, + null, // single-stream real-time has no model id + relevantAttribution, + relativeIndex, + pastValues, + expectedValuesList, + likelihoodOfValues, + threshold + ); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java index e0029edf2..d7454bcda 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java @@ -43,7 +43,6 @@ import org.opensearch.action.support.IndicesOptions; import org.opensearch.action.support.ThreadedActionListener; import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.cluster.HashRing; @@ -70,17 +69,18 @@ import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.io.stream.NotSerializableExceptionWrapper; import org.opensearch.common.lease.Releasable; import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.NetworkExceptionHelper; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.index.IndexNotFoundException; import org.opensearch.node.NodeClosedException; -import org.opensearch.rest.RestStatus; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.ClientException; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.InternalFailure; @@ -311,7 +311,7 @@ public void onResponse(CompositeRetriever.Page entityFeatures) { } if (entityFeatures != null && false == entityFeatures.isEmpty()) { // wrap expensive operation inside ad threadpool - threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME).execute(() -> { + threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME).execute(() -> { try { Set>> node2Entities = entityFeatures @@ -517,6 +517,7 @@ private void executeAnomalyDetection( DiscoveryNode rcfNode = asRCFNode.get(); + // we have already returned listener inside shouldStart method if (!shouldStart(listener, adID, anomalyDetector, rcfNode.getId(), rcfModelID)) { return; } @@ -1006,7 +1007,13 @@ private void coldStart(AnomalyDetector detector) { .trainModel( detector, dataPoints, - new ThreadedActionListener<>(LOG, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, trainModelListener, false) + new ThreadedActionListener<>( + LOG, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + trainModelListener, + false + ) ); } else { stateManager.setException(detectorId, new EndRunException(detectorId, "Cannot get training data", false)); @@ -1026,7 +1033,7 @@ private void coldStart(AnomalyDetector detector) { 
.runAfter(listener, coldStartFinishingCallback::close); threadPool - .executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME) + .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) .execute( () -> featureManager .getColdStartData( @@ -1034,7 +1041,7 @@ private void coldStart(AnomalyDetector detector) { new ThreadedActionListener<>( LOG, threadPool, - AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, listenerWithReleaseCallback, false ) diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java-e new file mode 100644 index 000000000..518ba0902 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java-e @@ -0,0 +1,1138 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.PAGE_SIZE; +import static org.opensearch.timeseries.constant.CommonMessages.INVALID_SEARCH_QUERY_MSG; + +import java.net.ConnectException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchTimeoutException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionListenerResponseHandler; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.search.SearchPhaseExecutionException; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.action.support.ThreadedActionListener; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.feature.CompositeRetriever; +import org.opensearch.ad.feature.CompositeRetriever.PageIterator; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.SinglePointFeatures; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.SingleStreamModelIdMapper; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.ad.util.SecurityClientUtil; +import 
org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlockLevel; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.NetworkExceptionHelper; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.node.NodeClosedException; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.ClientException; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.common.exception.NotSerializedExceptionName; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.transport.ActionNotFoundTransportException; +import org.opensearch.transport.ConnectTransportException; +import org.opensearch.transport.NodeNotConnectedException; +import org.opensearch.transport.ReceiveTimeoutTransportException; +import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.TransportService; + +public class AnomalyResultTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(AnomalyResultTransportAction.class); + static final String WAIT_FOR_THRESHOLD_ERR_MSG = "Exception in waiting for threshold result"; + static final String NODE_UNRESPONSIVE_ERR_MSG = "Model node is unresponsive. 
Mute node"; + static final String READ_WRITE_BLOCKED = "Cannot read/write due to global block."; + static final String INDEX_READ_BLOCKED = "Cannot read user index due to read block."; + static final String NULL_RESPONSE = "Received null response from"; + + static final String TROUBLE_QUERYING_ERR_MSG = "Having trouble querying data: "; + static final String NO_ACK_ERR = "no acknowledgements from model hosting nodes."; + + private final TransportService transportService; + private final NodeStateManager stateManager; + private final FeatureManager featureManager; + private final ModelManager modelManager; + private final HashRing hashRing; + private final TransportRequestOptions option; + private final ClusterService clusterService; + private final IndexNameExpressionResolver indexNameExpressionResolver; + private final ADStats adStats; + private final ADCircuitBreakerService adCircuitBreakerService; + private final ThreadPool threadPool; + private final Client client; + private final SecurityClientUtil clientUtil; + private final ADTaskManager adTaskManager; + + // Cache HC detector id. This is used to count HC failure stats. We can tell a detector + // is HC or not by checking if detector id exists in this field or not. Will add + // detector id to this field when start to run realtime detection and remove detector + // id once realtime detection done. + private final Set hcDetectors; + private NamedXContentRegistry xContentRegistry; + private Settings settings; + // within an interval, how many percents are used to process requests. + // 1.0 means we use all of the detection interval to process requests. + // to ensure we don't block next interval, it is better to set it less than 1.0. + private final float intervalRatioForRequest; + private int maxEntitiesPerInterval; + private int pageSize; + + @Inject + public AnomalyResultTransportAction( + ActionFilters actionFilters, + TransportService transportService, + Settings settings, + Client client, + SecurityClientUtil clientUtil, + NodeStateManager manager, + FeatureManager featureManager, + ModelManager modelManager, + HashRing hashRing, + ClusterService clusterService, + IndexNameExpressionResolver indexNameExpressionResolver, + ADCircuitBreakerService adCircuitBreakerService, + ADStats adStats, + ThreadPool threadPool, + NamedXContentRegistry xContentRegistry, + ADTaskManager adTaskManager + ) { + super(AnomalyResultAction.NAME, transportService, actionFilters, AnomalyResultRequest::new); + this.transportService = transportService; + this.settings = settings; + this.client = client; + this.clientUtil = clientUtil; + this.stateManager = manager; + this.featureManager = featureManager; + this.modelManager = modelManager; + this.hashRing = hashRing; + this.option = TransportRequestOptions + .builder() + .withType(TransportRequestOptions.Type.REG) + .withTimeout(AnomalyDetectorSettings.REQUEST_TIMEOUT.get(settings)) + .build(); + this.clusterService = clusterService; + this.indexNameExpressionResolver = indexNameExpressionResolver; + this.adCircuitBreakerService = adCircuitBreakerService; + this.adStats = adStats; + this.threadPool = threadPool; + this.hcDetectors = new HashSet<>(); + this.xContentRegistry = xContentRegistry; + this.intervalRatioForRequest = AnomalyDetectorSettings.INTERVAL_RATIO_FOR_REQUESTS; + + this.maxEntitiesPerInterval = MAX_ENTITIES_PER_QUERY.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_ENTITIES_PER_QUERY, it -> maxEntitiesPerInterval = it); + + this.pageSize = 
PAGE_SIZE.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(PAGE_SIZE, it -> pageSize = it); + this.adTaskManager = adTaskManager; + } + + /** + * All the exceptions thrown by AD is a subclass of AnomalyDetectionException. + * ClientException is a subclass of AnomalyDetectionException. All exception visible to + * Client is under ClientVisible. Two classes directly extends ClientException: + * - InternalFailure for "root cause unknown failure. Maybe transient." We can continue the + * detector running. + * - EndRunException for "failures that might impact the customer." The method endNow() is + * added to indicate whether the client should immediately terminate running a detector. + * + endNow() returns true for "unrecoverable issue". We want to terminate the detector run + * immediately. + * + endNow() returns false for "maybe unrecoverable issue but worth retrying a few more + * times." We want to wait for a few more times on different requests before terminating + * the detector run. + * + * AD may not be able to get an anomaly grade but can find a feature vector. Consider the + * case when the shingle is not ready. In that case, AD just put NaN as anomaly grade and + * return the feature vector. If AD cannot even find a feature vector, AD throws + * EndRunException if there is an issue or returns empty response (all the numeric fields + * are Double.NaN and feature array is empty. Do so so that customer can write painless + * script.) otherwise. + * + * + * Known causes of EndRunException with endNow returning false: + * + training data for cold start not available + * + cold start cannot succeed + * + unknown prediction error + * + memory circuit breaker tripped + * + invalid search query + * + * Known causes of EndRunException with endNow returning true: + * + a model partition's memory size reached limit + * + models' total memory size reached limit + * + Having trouble querying feature data due to + * * index does not exist + * * all features have been disabled + * + * + anomaly detector is not available + * + AD plugin is disabled + * + training data is invalid due to serious internal bug(s) + * + * Known causes of InternalFailure: + * + threshold model node is not available + * + cluster read/write is blocked + * + cold start hasn't been finished + * + fail to get all of rcf model nodes' responses + * + fail to get threshold model node's response + * + RCF/Threshold model node failing to get checkpoint to restore model before timeout + * + Detection is throttle because previous detection query is running + * + */ + @Override + protected void doExecute(Task task, ActionRequest actionRequest, ActionListener listener) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + AnomalyResultRequest request = AnomalyResultRequest.fromActionRequest(actionRequest); + String adID = request.getAdID(); + ActionListener original = listener; + listener = ActionListener.wrap(r -> { + hcDetectors.remove(adID); + original.onResponse(r); + }, e -> { + // If exception is AnomalyDetectionException and it should not be counted in stats, + // we will not count it in failure stats. 
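[Editor's illustration] As a companion to the exception taxonomy in the class-level Javadoc above, this is a hedged sketch of the branch a caller takes when an EndRunException surfaces; only isEndNow() comes from the code in this file, the helper method names are hypothetical.

    try {
        runOneDetectionInterval();                 // hypothetical helper that may throw EndRunException
    } catch (EndRunException e) {
        if (e.isEndNow()) {
            stopDetectorJobImmediately(e);         // unrecoverable (e.g. memory limit reached, detector unavailable)
        } else {
            recordErrorAndRetryNextInterval(e);    // maybe recoverable (e.g. cold-start training data not yet available)
        }
    }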
+ if (!(e instanceof TimeSeriesException) || ((TimeSeriesException) e).isCountedInStats()) { + adStats.getStat(StatNames.AD_EXECUTE_FAIL_COUNT.getName()).increment(); + if (hcDetectors.contains(adID)) { + adStats.getStat(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName()).increment(); + } + } + hcDetectors.remove(adID); + original.onFailure(e); + }); + + if (!ADEnabledSetting.isADEnabled()) { + throw new EndRunException(adID, ADCommonMessages.DISABLED_ERR_MSG, true).countedInStats(false); + } + + adStats.getStat(StatNames.AD_EXECUTE_REQUEST_COUNT.getName()).increment(); + + if (adCircuitBreakerService.isOpen()) { + listener.onFailure(new LimitExceededException(adID, CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); + return; + } + try { + stateManager.getAnomalyDetector(adID, onGetDetector(listener, adID, request)); + } catch (Exception ex) { + handleExecuteException(ex, listener, adID); + } + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + + /** + * didn't use ActionListener.wrap so that I can + * 1) use this to refer to the listener inside the listener + * 2) pass parameters using constructors + * + */ + class PageListener implements ActionListener { + private PageIterator pageIterator; + private String detectorId; + private long dataStartTime; + private long dataEndTime; + + PageListener(PageIterator pageIterator, String detectorId, long dataStartTime, long dataEndTime) { + this.pageIterator = pageIterator; + this.detectorId = detectorId; + this.dataStartTime = dataStartTime; + this.dataEndTime = dataEndTime; + } + + @Override + public void onResponse(CompositeRetriever.Page entityFeatures) { + if (pageIterator.hasNext()) { + pageIterator.next(this); + } + if (entityFeatures != null && false == entityFeatures.isEmpty()) { + // wrap expensive operation inside ad threadpool + threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME).execute(() -> { + try { + + Set>> node2Entities = entityFeatures + .getResults() + .entrySet() + .stream() + .filter(e -> hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(e.getKey().toString()).isPresent()) + .collect( + Collectors + .groupingBy( + // from entity name to its node + e -> hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(e.getKey().toString()).get(), + Collectors.toMap(Entry::getKey, Entry::getValue) + ) + ) + .entrySet(); + + Iterator>> iterator = node2Entities.iterator(); + + while (iterator.hasNext()) { + Entry> entry = iterator.next(); + DiscoveryNode modelNode = entry.getKey(); + if (modelNode == null) { + iterator.remove(); + continue; + } + String modelNodeId = modelNode.getId(); + if (stateManager.isMuted(modelNodeId, detectorId)) { + LOG + .info( + String + .format(Locale.ROOT, NODE_UNRESPONSIVE_ERR_MSG + " %s for detector %s", modelNodeId, detectorId) + ); + iterator.remove(); + } + } + + final AtomicReference failure = new AtomicReference<>(); + node2Entities.stream().forEach(nodeEntity -> { + DiscoveryNode node = nodeEntity.getKey(); + transportService + .sendRequest( + node, + EntityResultAction.NAME, + new EntityResultRequest(detectorId, nodeEntity.getValue(), dataStartTime, dataEndTime), + option, + new ActionListenerResponseHandler<>( + new EntityResultListener(node.getId(), detectorId, failure), + AcknowledgedResponse::new, + ThreadPool.Names.SAME + ) + ); + }); + + } catch (Exception e) { + LOG.error("Unexpected exception", e); + handleException(e); + } + }); + } + } + + @Override + public void onFailure(Exception e) { + LOG.error("Unexpetected exception", e); + 
handleException(e); + } + + private void handleException(Exception e) { + Exception convertedException = convertedQueryFailureException(e, detectorId); + if (false == (convertedException instanceof TimeSeriesException)) { + Throwable cause = ExceptionsHelper.unwrapCause(convertedException); + convertedException = new InternalFailure(detectorId, cause); + } + stateManager.setException(detectorId, convertedException); + } + } + + private ActionListener> onGetDetector( + ActionListener listener, + String adID, + AnomalyResultRequest request + ) { + return ActionListener.wrap(detectorOptional -> { + if (!detectorOptional.isPresent()) { + listener.onFailure(new EndRunException(adID, "AnomalyDetector is not available.", true)); + return; + } + + AnomalyDetector anomalyDetector = detectorOptional.get(); + if (anomalyDetector.isHighCardinality()) { + hcDetectors.add(adID); + adStats.getStat(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName()).increment(); + } + + long delayMillis = Optional + .ofNullable((IntervalTimeConfiguration) anomalyDetector.getWindowDelay()) + .map(t -> t.toDuration().toMillis()) + .orElse(0L); + long dataStartTime = request.getStart() - delayMillis; + long dataEndTime = request.getEnd() - delayMillis; + + adTaskManager + .initRealtimeTaskCacheAndCleanupStaleCache( + adID, + anomalyDetector, + transportService, + ActionListener + .runAfter( + initRealtimeTaskCacheListener(adID), + () -> executeAnomalyDetection(listener, adID, request, anomalyDetector, dataStartTime, dataEndTime) + ) + ); + }, exception -> handleExecuteException(exception, listener, adID)); + } + + private ActionListener initRealtimeTaskCacheListener(String detectorId) { + return ActionListener.wrap(r -> { + if (r) { + LOG.debug("Realtime task cache initied for detector {}", detectorId); + } + }, e -> LOG.error("Failed to init realtime task cache for " + detectorId, e)); + } + + private void executeAnomalyDetection( + ActionListener listener, + String adID, + AnomalyResultRequest request, + AnomalyDetector anomalyDetector, + long dataStartTime, + long dataEndTime + ) { + // HC logic starts here + if (anomalyDetector.isHighCardinality()) { + Optional previousException = stateManager.fetchExceptionAndClear(adID); + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error(new ParameterizedMessage("Previous exception of [{}]", adID), exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + } + + // assume request are in epoch milliseconds + long nextDetectionStartTime = request.getEnd() + (long) (anomalyDetector.getIntervalInMilliseconds() * intervalRatioForRequest); + + CompositeRetriever compositeRetriever = new CompositeRetriever( + dataStartTime, + dataEndTime, + anomalyDetector, + xContentRegistry, + client, + clientUtil, + nextDetectionStartTime, + settings, + maxEntitiesPerInterval, + pageSize, + indexNameExpressionResolver, + clusterService + ); + + PageIterator pageIterator = null; + + try { + pageIterator = compositeRetriever.iterator(); + } catch (Exception e) { + listener.onFailure(new EndRunException(anomalyDetector.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, false)); + return; + } + + PageListener getEntityFeatureslistener = new PageListener(pageIterator, adID, dataStartTime, dataEndTime); + if (pageIterator.hasNext()) { + pageIterator.next(getEntityFeatureslistener); + } + + // We don't know when the 
pagination will not finish. To not + // block the following interval request to start, we return immediately. + // Pagination will stop itself when the time is up. + if (previousException.isPresent()) { + listener.onFailure(previousException.get()); + } else { + listener + .onResponse( + new AnomalyResultResponse(new ArrayList(), null, null, anomalyDetector.getIntervalInMinutes(), true) + ); + } + return; + } + + // HC logic ends and single entity logic starts here + // We are going to use only 1 model partition for a single stream detector. + // That's why we use 0 here. + String rcfModelID = SingleStreamModelIdMapper.getRcfModelId(adID, 0); + Optional asRCFNode = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(rcfModelID); + if (!asRCFNode.isPresent()) { + listener.onFailure(new InternalFailure(adID, "RCF model node is not available.")); + return; + } + + DiscoveryNode rcfNode = asRCFNode.get(); + + // we have already returned listener inside shouldStart method + if (!shouldStart(listener, adID, anomalyDetector, rcfNode.getId(), rcfModelID)) { + return; + } + + featureManager + .getCurrentFeatures( + anomalyDetector, + dataStartTime, + dataEndTime, + onFeatureResponseForSingleEntityDetector(adID, anomalyDetector, listener, rcfModelID, rcfNode, dataStartTime, dataEndTime) + ); + } + + // For single entity detector + private ActionListener onFeatureResponseForSingleEntityDetector( + String adID, + AnomalyDetector detector, + ActionListener listener, + String rcfModelId, + DiscoveryNode rcfNode, + long dataStartTime, + long dataEndTime + ) { + return ActionListener.wrap(featureOptional -> { + List featureInResponse = null; + if (featureOptional.getUnprocessedFeatures().isPresent()) { + featureInResponse = ParseUtils.getFeatureData(featureOptional.getUnprocessedFeatures().get(), detector); + } + + if (!featureOptional.getProcessedFeatures().isPresent()) { + Optional exception = coldStartIfNoCheckPoint(detector); + if (exception.isPresent()) { + listener.onFailure(exception.get()); + return; + } + + if (!featureOptional.getUnprocessedFeatures().isPresent()) { + // Feature not available is common when we have data holes. Respond empty response + // and don't log to avoid bloating our logs. 
+ LOG.debug("No data in current detection window between {} and {} for {}", dataStartTime, dataEndTime, adID); + listener + .onResponse( + new AnomalyResultResponse( + new ArrayList(), + "No data in current detection window", + null, + null, + false + ) + ); + } else { + LOG.debug("Return at least current feature value between {} and {} for {}", dataStartTime, dataEndTime, adID); + listener + .onResponse( + new AnomalyResultResponse(featureInResponse, "No full shingle in current detection window", null, null, false) + ); + } + return; + } + + final AtomicReference failure = new AtomicReference(); + + LOG.info("Sending RCF request to {} for model {}", rcfNode.getId(), rcfModelId); + + RCFActionListener rcfListener = new RCFActionListener( + rcfModelId, + failure, + rcfNode.getId(), + detector, + listener, + featureInResponse, + adID + ); + + transportService + .sendRequest( + rcfNode, + RCFResultAction.NAME, + new RCFResultRequest(adID, rcfModelId, featureOptional.getProcessedFeatures().get()), + option, + new ActionListenerResponseHandler<>(rcfListener, RCFResultResponse::new) + ); + }, exception -> { handleQueryFailure(exception, listener, adID); }); + } + + private void handleQueryFailure(Exception exception, ActionListener listener, String adID) { + Exception convertedQueryFailureException = convertedQueryFailureException(exception, adID); + + if (convertedQueryFailureException instanceof EndRunException) { + // invalid feature query + listener.onFailure(convertedQueryFailureException); + } else { + handleExecuteException(convertedQueryFailureException, listener, adID); + } + } + + /** + * Convert a query related exception to EndRunException + * + * These query exception can happen during the starting phase of the OpenSearch + * process. Thus, set the stopNow parameter of these EndRunException to false + * and confirm the EndRunException is not a false positive. + * + * @param exception Exception + * @param adID detector Id + * @return the converted exception if the exception is query related + */ + private Exception convertedQueryFailureException(Exception exception, String adID) { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + return new EndRunException(adID, TROUBLE_QUERYING_ERR_MSG + exception.getMessage(), false).countedInStats(false); + } else if (exception instanceof SearchPhaseExecutionException && invalidQuery((SearchPhaseExecutionException) exception)) { + // This is to catch invalid aggregation on wrong field type. For example, + // sum aggregation on text field. We should end detector run for such case. + return new EndRunException( + adID, + INVALID_SEARCH_QUERY_MSG + " " + ((SearchPhaseExecutionException) exception).getDetailedMessage(), + exception, + false + ).countedInStats(false); + } + + return exception; + } + + /** + * Verify failure of rcf or threshold models. If there is no model, trigger cold + * start. If there is an exception for the previous cold start of this detector, + * throw exception to the caller. + * + * @param failure object that may contain exceptions thrown + * @param detector detector object + * @return exception if AD job execution gets resource not found exception + * @throws Exception when the input failure is not a ResourceNotFoundException. + * List of exceptions we can throw + * 1. Exception from cold start: + * 1). InternalFailure due to + * a. OpenSearchTimeoutException thrown by putModelCheckpoint during cold start + * 2). EndRunException with endNow equal to false + * a. training data not available + * b. 
cold start cannot succeed + * c. invalid training data + * 3) EndRunException with endNow equal to true + * a. invalid search query + * 2. LimitExceededException from one of RCF model node when the total size of the models + * is more than X% of heap memory. + * 3. InternalFailure wrapping OpenSearchTimeoutException inside caused by + * RCF/Threshold model node failing to get checkpoint to restore model before timeout. + */ + private Exception coldStartIfNoModel(AtomicReference failure, AnomalyDetector detector) throws Exception { + Exception exp = failure.get(); + if (exp == null) { + return null; + } + + // return exceptions like LimitExceededException to caller + if (!(exp instanceof ResourceNotFoundException)) { + return exp; + } + + // fetch previous cold start exception + String adID = detector.getId(); + final Optional previousException = stateManager.fetchExceptionAndClear(adID); + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error("Previous exception of {}: {}", () -> adID, () -> exception); + if (exception instanceof EndRunException && ((EndRunException) exception).isEndNow()) { + return exception; + } + } + LOG.info("Trigger cold start for {}", detector.getId()); + coldStart(detector); + return previousException.orElse(new InternalFailure(adID, ADCommonMessages.NO_MODEL_ERR_MSG)); + } + + private void findException(Throwable cause, String adID, AtomicReference failure, String nodeId) { + if (cause == null) { + LOG.error(new ParameterizedMessage("Null input exception")); + return; + } + if (cause instanceof Error) { + // we cannot do anything with Error. + LOG.error(new ParameterizedMessage("Error during prediction for {}: ", adID), cause); + return; + } + + Exception causeException = (Exception) cause; + + if (causeException instanceof TimeSeriesException) { + failure.set(causeException); + } else if (causeException instanceof NotSerializableExceptionWrapper) { + // we only expect this happens on AD exceptions + Optional actualException = NotSerializedExceptionName + .convertWrappedTimeSeriesException((NotSerializableExceptionWrapper) causeException, adID); + if (actualException.isPresent()) { + TimeSeriesException adException = actualException.get(); + failure.set(adException); + if (adException instanceof ResourceNotFoundException) { + // During a rolling upgrade or blue/green deployment, ResourceNotFoundException might be caused by old node using RCF + // 1.0 + // cannot recognize new checkpoint produced by the coordinating node using compact RCF. Add pressure to mute the node + // after consecutive failures. 
+ stateManager.addPressure(nodeId, adID); + } + } else { + // some unexpected bugs occur while predicting anomaly + failure.set(new EndRunException(adID, CommonMessages.BUG_RESPONSE, causeException, false)); + } + } else if (causeException instanceof IndexNotFoundException + && causeException.getMessage().contains(ADCommonName.CHECKPOINT_INDEX_NAME)) { + // checkpoint index does not exist + // ResourceNotFoundException will trigger cold start later + failure.set(new ResourceNotFoundException(adID, causeException.getMessage())); + } else if (causeException instanceof OpenSearchTimeoutException) { + // we can have OpenSearchTimeoutException when a node tries to load RCF or + // threshold model + failure.set(new InternalFailure(adID, causeException)); + } else if (causeException instanceof IllegalArgumentException) { + // we can have IllegalArgumentException when a model is corrupted + failure.set(new InternalFailure(adID, causeException)); + } else { + // some unexpected bug occurred or cluster is unstable (e.g., ClusterBlockException) or index is red (e.g. + // NoShardAvailableActionException) while predicting anomaly + failure.set(new EndRunException(adID, CommonMessages.BUG_RESPONSE, causeException, false)); + } + } + + void handleExecuteException(Exception ex, ActionListener listener, String adID) { + if (ex instanceof ClientException) { + listener.onFailure(ex); + } else if (ex instanceof TimeSeriesException) { + listener.onFailure(new InternalFailure((TimeSeriesException) ex)); + } else { + Throwable cause = ExceptionsHelper.unwrapCause(ex); + listener.onFailure(new InternalFailure(adID, cause)); + } + } + + private boolean invalidQuery(SearchPhaseExecutionException ex) { + // If all shards return bad request and failure cause is IllegalArgumentException, we + // consider the feature query is invalid and will not count the error in failure stats. 
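+ // For example, a sum aggregation on a text field fails on every shard with a 400 status and an IllegalArgumentException cause; the loop below treats that combination as an invalid feature query.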
+ for (ShardSearchFailure failure : ex.shardFailures()) { + if (RestStatus.BAD_REQUEST != failure.status() || !(failure.getCause() instanceof IllegalArgumentException)) { + return false; + } + } + return true; + } + + // For single entity detector + class RCFActionListener implements ActionListener { + private String modelID; + private AtomicReference failure; + private String rcfNodeID; + private AnomalyDetector detector; + private ActionListener listener; + private List featureInResponse; + private final String adID; + + RCFActionListener( + String modelID, + AtomicReference failure, + String rcfNodeID, + AnomalyDetector detector, + ActionListener listener, + List features, + String adID + ) { + this.modelID = modelID; + this.failure = failure; + this.rcfNodeID = rcfNodeID; + this.detector = detector; + this.listener = listener; + this.featureInResponse = features; + this.adID = adID; + } + + @Override + public void onResponse(RCFResultResponse response) { + try { + stateManager.resetBackpressureCounter(rcfNodeID, adID); + if (response != null) { + listener + .onResponse( + new AnomalyResultResponse( + response.getAnomalyGrade(), + response.getConfidence(), + response.getRCFScore(), + featureInResponse, + null, + response.getTotalUpdates(), + detector.getIntervalInMinutes(), + false, + response.getRelativeIndex(), + response.getAttribution(), + response.getPastValues(), + response.getExpectedValuesList(), + response.getLikelihoodOfValues(), + response.getThreshold() + ) + ); + } else { + LOG.warn(NULL_RESPONSE + " {} for {}", modelID, rcfNodeID); + listener.onFailure(new InternalFailure(adID, ADCommonMessages.NO_MODEL_ERR_MSG)); + } + } catch (Exception ex) { + LOG.error(new ParameterizedMessage("Unexpected exception for [{}]", adID), ex); + handleExecuteException(ex, listener, adID); + } + } + + @Override + public void onFailure(Exception e) { + try { + handlePredictionFailure(e, adID, rcfNodeID, failure); + Exception exception = coldStartIfNoModel(failure, detector); + if (exception != null) { + listener.onFailure(exception); + } else { + listener.onFailure(new InternalFailure(adID, "Node connection problem or unexpected exception")); + } + } catch (Exception ex) { + LOG.error(new ParameterizedMessage("Unexpected exception for [{}]", adID), ex); + handleExecuteException(ex, listener, adID); + } + } + } + + /** + * Handle a prediction failure. Possibly (i.e., we don't always need to do that) + * convert the exception to a form that AD can recognize and handle and sets the + * input failure reference to the converted exception. + * + * @param e prediction exception + * @param adID Detector Id + * @param nodeID Node Id + * @param failure Parameter to receive the possibly converted function for the + * caller to deal with + */ + private void handlePredictionFailure(Exception e, String adID, String nodeID, AtomicReference failure) { + LOG.error(new ParameterizedMessage("Received an error from node {} while doing model inference for {}", nodeID, adID), e); + if (e == null) { + return; + } + Throwable cause = ExceptionsHelper.unwrapCause(e); + if (hasConnectionIssue(cause)) { + handleConnectionException(nodeID, adID); + } else { + findException(cause, adID, failure, nodeID); + } + } + + /** + * Check if the input exception indicates connection issues. + * During blue-green deployment, we may see ActionNotFoundTransportException. + * Count that as connection issue and isolate that node if it continues to happen. 
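+ * The check below treats disconnects (ConnectTransportException, NodeNotConnectedException, ConnectException, closed connections), node shutdown (NodeClosedException), transport receive timeouts (ReceiveTimeoutTransportException), and ActionNotFoundTransportException as connection issues.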
+ * + * @param e exception + * @return true if we get disconnected from the node or the node is not in the + * right state (being closed) or transport request times out (sent from TimeoutHandler.run) + */ + private boolean hasConnectionIssue(Throwable e) { + return e instanceof ConnectTransportException + || e instanceof NodeClosedException + || e instanceof ReceiveTimeoutTransportException + || e instanceof NodeNotConnectedException + || e instanceof ConnectException + || NetworkExceptionHelper.isCloseConnectionException(e) + || e instanceof ActionNotFoundTransportException; + } + + private void handleConnectionException(String node, String detectorId) { + final DiscoveryNodes nodes = clusterService.state().nodes(); + if (!nodes.nodeExists(node)) { + hashRing.buildCirclesForRealtimeAD(); + return; + } + // rebuilding is not done or node is unresponsive + stateManager.addPressure(node, detectorId); + } + + /** + * Since we need to read from customer index and write to anomaly result index, + * we need to make sure we can read and write. + * + * @param state Cluster state + * @return whether we have global block or not + */ + private boolean checkGlobalBlock(ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.READ) != null + || state.blocks().globalBlockedException(ClusterBlockLevel.WRITE) != null; + } + + /** + * Similar to checkGlobalBlock, we check block on the indices level. + * + * @param state Cluster state + * @param level block level + * @param indices the indices on which to check block + * @return whether any of the index has block on the level. + */ + private boolean checkIndicesBlocked(ClusterState state, ClusterBlockLevel level, String... indices) { + // the original index might be an index expression with wildcards like "log*", + // so we need to expand the expression to concrete index name + String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(state, IndicesOptions.lenientExpandOpen(), indices); + + return state.blocks().indicesBlockedException(level, concreteIndices) != null; + } + + /** + * Check if we should start anomaly prediction. + * + * @param listener listener to respond back to AnomalyResultRequest. + * @param adID detector ID + * @param detector detector instance corresponds to adID + * @param rcfNodeId the rcf model hosting node ID for adID + * @param rcfModelID the rcf model ID for adID + * @return if we can start anomaly prediction. 
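+ * The checks below return an InternalFailure and stop the run when the cluster has a global read or write block, when the RCF node is muted for this detector, or when any of the detector's source indices has a read block.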
+ */ + private boolean shouldStart( + ActionListener listener, + String adID, + AnomalyDetector detector, + String rcfNodeId, + String rcfModelID + ) { + ClusterState state = clusterService.state(); + if (checkGlobalBlock(state)) { + listener.onFailure(new InternalFailure(adID, READ_WRITE_BLOCKED)); + return false; + } + + if (stateManager.isMuted(rcfNodeId, adID)) { + listener + .onFailure( + new InternalFailure( + adID, + String.format(Locale.ROOT, NODE_UNRESPONSIVE_ERR_MSG + " %s for rcf model %s", rcfNodeId, rcfModelID) + ) + ); + return false; + } + + if (checkIndicesBlocked(state, ClusterBlockLevel.READ, detector.getIndices().toArray(new String[0]))) { + listener.onFailure(new InternalFailure(adID, INDEX_READ_BLOCKED)); + return false; + } + + return true; + } + + private void coldStart(AnomalyDetector detector) { + String detectorId = detector.getId(); + + // If last cold start is not finished, we don't trigger another one + if (stateManager.isColdStartRunning(detectorId)) { + return; + } + + final Releasable coldStartFinishingCallback = stateManager.markColdStartRunning(detectorId); + + ActionListener> listener = ActionListener.wrap(trainingData -> { + if (trainingData.isPresent()) { + double[][] dataPoints = trainingData.get(); + + ActionListener trainModelListener = ActionListener + .wrap(res -> { LOG.info("Succeeded in training {}", detectorId); }, exception -> { + if (exception instanceof TimeSeriesException) { + // e.g., partitioned model exceeds memory limit + stateManager.setException(detectorId, exception); + } else if (exception instanceof IllegalArgumentException) { + // IllegalArgumentException due to invalid training data + stateManager + .setException(detectorId, new EndRunException(detectorId, "Invalid training data", exception, false)); + } else if (exception instanceof OpenSearchTimeoutException) { + stateManager + .setException( + detectorId, + new InternalFailure(detectorId, "Time out while indexing cold start checkpoint", exception) + ); + } else { + stateManager + .setException(detectorId, new EndRunException(detectorId, "Error while training model", exception, false)); + } + }); + + modelManager + .trainModel( + detector, + dataPoints, + new ThreadedActionListener<>( + LOG, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + trainModelListener, + false + ) + ); + } else { + stateManager.setException(detectorId, new EndRunException(detectorId, "Cannot get training data", false)); + } + }, exception -> { + if (exception instanceof OpenSearchTimeoutException) { + stateManager.setException(detectorId, new InternalFailure(detectorId, "Time out while getting training data", exception)); + } else if (exception instanceof TimeSeriesException) { + // e.g., Invalid search query + stateManager.setException(detectorId, exception); + } else { + stateManager.setException(detectorId, new EndRunException(detectorId, "Error while cold start", exception, false)); + } + }); + + final ActionListener> listenerWithReleaseCallback = ActionListener + .runAfter(listener, coldStartFinishingCallback::close); + + threadPool + .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) + .execute( + () -> featureManager + .getColdStartData( + detector, + new ThreadedActionListener<>( + LOG, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + listenerWithReleaseCallback, + false + ) + ) + ); + } + + /** + * Check if checkpoint for an detector exists or not. If not and previous + * run is not EndRunException whose endNow is true, trigger cold start. 
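+ * An IndexNotFoundException while looking up the checkpoint is treated the same as a missing checkpoint and also triggers cold start; other lookup failures are stored on the node state as a TimeSeriesException.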
+ * @param detector detector object + * @return previous cold start exception + */ + private Optional coldStartIfNoCheckPoint(AnomalyDetector detector) { + String detectorId = detector.getId(); + + Optional previousException = stateManager.fetchExceptionAndClear(detectorId); + + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error(new ParameterizedMessage("Previous exception of {}:", detectorId), exception); + if (exception instanceof EndRunException && ((EndRunException) exception).isEndNow()) { + return previousException; + } + } + + stateManager.getDetectorCheckpoint(detectorId, ActionListener.wrap(checkpointExists -> { + if (!checkpointExists) { + LOG.info("Trigger cold start for {}", detectorId); + coldStart(detector); + } + }, exception -> { + Throwable cause = ExceptionsHelper.unwrapCause(exception); + if (cause instanceof IndexNotFoundException) { + LOG.info("Trigger cold start for {}", detectorId); + coldStart(detector); + } else { + String errorMsg = String.format(Locale.ROOT, "Fail to get checkpoint state for %s", detectorId); + LOG.error(errorMsg, exception); + stateManager.setException(detectorId, new TimeSeriesException(errorMsg, exception)); + } + })); + + return previousException; + } + + class EntityResultListener implements ActionListener { + private String nodeId; + private final String adID; + private AtomicReference failure; + + EntityResultListener(String nodeId, String adID, AtomicReference failure) { + this.nodeId = nodeId; + this.adID = adID; + this.failure = failure; + } + + @Override + public void onResponse(AcknowledgedResponse response) { + try { + if (response.isAcknowledged() == false) { + LOG.error("Cannot send entities' features to {} for {}", nodeId, adID); + stateManager.addPressure(nodeId, adID); + } else { + stateManager.resetBackpressureCounter(nodeId, adID); + } + } catch (Exception ex) { + LOG.error("Unexpected exception: {} for {}", ex, adID); + handleException(ex); + } + } + + @Override + public void onFailure(Exception e) { + try { + // e.g., we have connection issues with all of the nodes while restarting clusters + LOG.error(new ParameterizedMessage("Cannot send entities' features to {} for {}", nodeId, adID), e); + + handleException(e); + + } catch (Exception ex) { + LOG.error("Unexpected exception: {} for {}", ex, adID); + handleException(ex); + } + } + + private void handleException(Exception e) { + handlePredictionFailure(e, adID, nodeId, failure); + if (failure.get() != null) { + stateManager.setException(adID, failure.get()); + } + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/BackPressureRouting.java-e b/src/main/java/org/opensearch/ad/transport/BackPressureRouting.java-e new file mode 100644 index 000000000..e5f4ba9b8 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/BackPressureRouting.java-e @@ -0,0 +1,95 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.time.Clock; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.unit.TimeValue; + +/** + * Data structure to keep track of a node's unresponsive history: a node does not reply for a + * certain consecutive times gets muted for some time. + */ +public class BackPressureRouting { + private static final Logger LOG = LogManager.getLogger(BackPressureRouting.class); + private final String nodeId; + private final Clock clock; + private int maxRetryForUnresponsiveNode; + private TimeValue mutePeriod; + private AtomicInteger backpressureCounter; + private long lastMuteTime; + + public BackPressureRouting(String nodeId, Clock clock, int maxRetryForUnresponsiveNode, TimeValue mutePeriod) { + this.nodeId = nodeId; + this.clock = clock; + this.backpressureCounter = new AtomicInteger(0); + this.maxRetryForUnresponsiveNode = maxRetryForUnresponsiveNode; + this.mutePeriod = mutePeriod; + this.lastMuteTime = 0; + } + + /** + * The caller of this method does not have to keep track of when to start + * muting. This method would mute by itself when we have accumulated enough + * unresponsive calls. + */ + public void addPressure() { + int currentRetry = backpressureCounter.incrementAndGet(); + LOG.info("{} has been unresponsive for {} times", nodeId, currentRetry); + if (currentRetry > this.maxRetryForUnresponsiveNode) { + mute(); + } + } + + /** + * We call this method to decide if a node is muted or not. If yes, we can send + * requests to the node; if not, skip sending requests. + * + * @return whether this node is muted or not + */ + public boolean isMuted() { + if (clock.millis() - lastMuteTime <= mutePeriod.getMillis()) { + return true; + } + return false; + } + + private void mute() { + lastMuteTime = clock.millis(); + } + + public int getMaxRetryForUnresponsiveNode() { + return maxRetryForUnresponsiveNode; + } + + /** + * Setter for maxRetryForUnresponsiveNode + * + * It is up to the client to make the method thread safe. + * + * @param maxRetryForUnresponsiveNode the max retries before muting a node. + */ + public void setMaxRetryForUnresponsiveNode(int maxRetryForUnresponsiveNode) { + this.maxRetryForUnresponsiveNode = maxRetryForUnresponsiveNode; + } + + public TimeValue getMutePeriod() { + return mutePeriod; + } + + public void setMutePeriod(TimeValue mutePeriod) { + this.mutePeriod = mutePeriod; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/CronAction.java-e b/src/main/java/org/opensearch/ad/transport/CronAction.java-e new file mode 100644 index 000000000..1e64a0f45 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/CronAction.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class CronAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. 
+ public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "cron"; + public static final CronAction INSTANCE = new CronAction(); + + private CronAction() { + super(NAME, CronResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/CronNodeRequest.java b/src/main/java/org/opensearch/ad/transport/CronNodeRequest.java index 25d200895..a5362ff46 100644 --- a/src/main/java/org/opensearch/ad/transport/CronNodeRequest.java +++ b/src/main/java/org/opensearch/ad/transport/CronNodeRequest.java @@ -13,7 +13,7 @@ import java.io.IOException; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.transport.TransportRequest; /** diff --git a/src/main/java/org/opensearch/ad/transport/CronNodeRequest.java-e b/src/main/java/org/opensearch/ad/transport/CronNodeRequest.java-e new file mode 100644 index 000000000..a5362ff46 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/CronNodeRequest.java-e @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.transport.TransportRequest; + +/** + * Delete model represents the request to an individual node + */ +public class CronNodeRequest extends TransportRequest { + + public CronNodeRequest() {} + + public CronNodeRequest(StreamInput in) throws IOException { + super(in); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/CronNodeResponse.java b/src/main/java/org/opensearch/ad/transport/CronNodeResponse.java index ed28e25d8..f1e9fb0e1 100644 --- a/src/main/java/org/opensearch/ad/transport/CronNodeResponse.java +++ b/src/main/java/org/opensearch/ad/transport/CronNodeResponse.java @@ -15,7 +15,7 @@ import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/CronNodeResponse.java-e b/src/main/java/org/opensearch/ad/transport/CronNodeResponse.java-e new file mode 100644 index 000000000..f1e9fb0e1 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/CronNodeResponse.java-e @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +public class CronNodeResponse extends BaseNodeResponse implements ToXContentObject { + static String NODE_ID = "node_id"; + + public CronNodeResponse(StreamInput in) throws IOException { + super(in); + } + + public CronNodeResponse(DiscoveryNode node) { + super(node); + } + + public static CronNodeResponse readNodeResponse(StreamInput in) throws IOException { + + return new CronNodeResponse(in); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(NODE_ID, getNode().getId()); + builder.endObject(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/CronRequest.java b/src/main/java/org/opensearch/ad/transport/CronRequest.java index 6de2814c1..0f91ae676 100644 --- a/src/main/java/org/opensearch/ad/transport/CronRequest.java +++ b/src/main/java/org/opensearch/ad/transport/CronRequest.java @@ -15,7 +15,7 @@ import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamInput; /** * Request should be sent from the handler logic of transport delete detector API diff --git a/src/main/java/org/opensearch/ad/transport/CronRequest.java-e b/src/main/java/org/opensearch/ad/transport/CronRequest.java-e new file mode 100644 index 000000000..0f91ae676 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/CronRequest.java-e @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; + +/** + * Request should be sent from the handler logic of transport delete detector API + * + */ +public class CronRequest extends BaseNodesRequest { + + public CronRequest() { + super((String[]) null); + } + + public CronRequest(StreamInput in) throws IOException { + super(in); + } + + public CronRequest(DiscoveryNode... 
nodes) { + super(nodes); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/CronResponse.java b/src/main/java/org/opensearch/ad/transport/CronResponse.java index 2a74eedc2..13332c3af 100644 --- a/src/main/java/org/opensearch/ad/transport/CronResponse.java +++ b/src/main/java/org/opensearch/ad/transport/CronResponse.java @@ -17,8 +17,8 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/CronResponse.java-e b/src/main/java/org/opensearch/ad/transport/CronResponse.java-e new file mode 100644 index 000000000..13332c3af --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/CronResponse.java-e @@ -0,0 +1,55 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; + +public class CronResponse extends BaseNodesResponse implements ToXContentFragment { + static String NODES_JSON_KEY = "nodes"; + + public CronResponse(StreamInput in) throws IOException { + super(in); + } + + public CronResponse(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + } + + @Override + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(CronNodeResponse::readNodeResponse); + } + + @Override + protected void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startArray(NODES_JSON_KEY); + for (CronNodeResponse nodeResp : getNodes()) { + nodeResp.toXContent(builder, params); + } + builder.endArray(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/CronTransportAction.java b/src/main/java/org/opensearch/ad/transport/CronTransportAction.java index 6a11f4525..edc21cd6f 100644 --- a/src/main/java/org/opensearch/ad/transport/CronTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/CronTransportAction.java @@ -28,7 +28,7 @@ import org.opensearch.ad.task.ADTaskManager; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; diff --git 
a/src/main/java/org/opensearch/ad/transport/CronTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/CronTransportAction.java-e new file mode 100644 index 000000000..edc21cd6f --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/CronTransportAction.java-e @@ -0,0 +1,130 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.List; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.ml.EntityColdStarter; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class CronTransportAction extends TransportNodesAction { + private final Logger LOG = LogManager.getLogger(CronTransportAction.class); + private NodeStateManager transportStateManager; + private ModelManager modelManager; + private FeatureManager featureManager; + private CacheProvider cacheProvider; + private EntityColdStarter entityColdStarter; + private ADTaskManager adTaskManager; + + @Inject + public CronTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + NodeStateManager tarnsportStatemanager, + ModelManager modelManager, + FeatureManager featureManager, + CacheProvider cacheProvider, + EntityColdStarter entityColdStarter, + ADTaskManager adTaskManager + ) { + super( + CronAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + CronRequest::new, + CronNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + CronNodeResponse.class + ); + this.transportStateManager = tarnsportStatemanager; + this.modelManager = modelManager; + this.featureManager = featureManager; + this.cacheProvider = cacheProvider; + this.entityColdStarter = entityColdStarter; + this.adTaskManager = adTaskManager; + } + + @Override + protected CronResponse newResponse(CronRequest request, List responses, List failures) { + return new CronResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected CronNodeRequest newNodeRequest(CronRequest request) { + return new CronNodeRequest(); + } + + @Override + protected CronNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new CronNodeResponse(in); + } + + /** + * Delete unused models and save checkpoints before deleting (including both RCF + * and thresholding model), buffered shingle data, and transport state + * + * @param request delete request + * @return delete response including local node Id. 
+ */ + @Override + protected CronNodeResponse nodeOperation(CronNodeRequest request) { + LOG.info("Start running AD hourly cron."); + // makes checkpoints for hosted models and stop hosting models not actively + // used. + // for single-entity detector + modelManager + .maintenance(ActionListener.wrap(v -> LOG.debug("model maintenance done"), e -> LOG.error("Error maintaining model", e))); + // for multi-entity detector + cacheProvider.get().maintenance(); + + // delete unused buffered shingle data + featureManager.maintenance(); + + // delete unused transport state + transportStateManager.maintenance(); + + entityColdStarter.maintenance(); + // clean child tasks and AD results of deleted detector level task + adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); + + // clean AD results of deleted detector + adTaskManager.cleanADResultOfDeletedDetector(); + + // maintain running historical tasks: reset task state as stopped if not running and clean stale running entities + adTaskManager.maintainRunningHistoricalTasks(transportService, 100); + + // maintain running realtime tasks: clean stale running realtime task cache + adTaskManager.maintainRunningRealtimeTasks(); + + return new CronNodeResponse(clusterService.localNode()); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorAction.java-e b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorAction.java-e new file mode 100644 index 000000000..75dc34638 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorAction.java-e @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.ad.constant.CommonValue; + +public class DeleteAnomalyDetectorAction extends ActionType { + // External Action which used for public facing RestAPIs. 
+ public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/delete"; + public static final DeleteAnomalyDetectorAction INSTANCE = new DeleteAnomalyDetectorAction(); + + private DeleteAnomalyDetectorAction() { + super(NAME, DeleteResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorRequest.java b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorRequest.java index 22686616b..f87b6e0a1 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorRequest.java @@ -18,9 +18,9 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; public class DeleteAnomalyDetectorRequest extends ActionRequest { diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorRequest.java-e b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorRequest.java-e new file mode 100644 index 000000000..913c56e44 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorRequest.java-e @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.Strings; + +public class DeleteAnomalyDetectorRequest extends ActionRequest { + + private String detectorID; + + public DeleteAnomalyDetectorRequest(StreamInput in) throws IOException { + super(in); + this.detectorID = in.readString(); + } + + public DeleteAnomalyDetectorRequest(String detectorID) { + super(); + this.detectorID = detectorID; + } + + public String getDetectorID() { + return detectorID; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(detectorID); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(detectorID)) { + validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + } + return validationException; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java index 4af2f5b8d..ebc7577f0 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java @@ -14,7 +14,7 @@ import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_DELETE_DETECTOR; import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.util.ParseUtils.getUserContext; import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; @@ -43,10 +43,10 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; -import org.opensearch.rest.RestStatus; import org.opensearch.tasks.Task; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.function.ExecutorFunction; diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java-e new file mode 100644 index 000000000..ce0eba10f --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java-e @@ -0,0 +1,230 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_DELETE_DETECTOR; +import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.ParseUtils.getUserContext; +import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import java.io.IOException; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +public class DeleteAnomalyDetectorTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(DeleteAnomalyDetectorTransportAction.class); + private final Client client; + private final ClusterService clusterService; + private final TransportService transportService; + private NamedXContentRegistry xContentRegistry; + private final ADTaskManager adTaskManager; + private volatile Boolean filterByEnabled; + + @Inject + public DeleteAnomalyDetectorTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService, + Settings settings, + NamedXContentRegistry xContentRegistry, + ADTaskManager adTaskManager + ) { + super(DeleteAnomalyDetectorAction.NAME, transportService, actionFilters, DeleteAnomalyDetectorRequest::new); + this.transportService = transportService; + this.client = client; + this.clusterService = clusterService; + this.xContentRegistry = xContentRegistry; + this.adTaskManager = adTaskManager; + filterByEnabled = AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = 
it); + } + + @Override + protected void doExecute(Task task, DeleteAnomalyDetectorRequest request, ActionListener actionListener) { + String detectorId = request.getDetectorID(); + LOG.info("Delete anomaly detector job {}", detectorId); + User user = getUserContext(client); + ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_DELETE_DETECTOR); + // By the time request reaches here, the user permissions are validated by Security plugin. + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute( + user, + detectorId, + filterByEnabled, + listener, + (anomalyDetector) -> adTaskManager.getDetector(detectorId, detector -> { + if (!detector.isPresent()) { + // In a mixed cluster, if delete detector request routes to node running AD1.0, then it will + // not delete detector tasks. User can re-delete these deleted detector after cluster upgraded, + // in that case, the detector is not present. + LOG.info("Can't find anomaly detector {}", detectorId); + adTaskManager.deleteADTasks(detectorId, () -> deleteAnomalyDetectorJobDoc(detectorId, listener), listener); + return; + } + // Check if there is realtime job or historical analysis task running. If none of these running, we + // can delete the detector. + getDetectorJob(detectorId, listener, () -> { + adTaskManager.getAndExecuteOnLatestDetectorLevelTask(detectorId, HISTORICAL_DETECTOR_TASK_TYPES, adTask -> { + if (adTask.isPresent() && !adTask.get().isDone()) { + listener.onFailure(new OpenSearchStatusException("Detector is running", RestStatus.INTERNAL_SERVER_ERROR)); + } else { + adTaskManager.deleteADTasks(detectorId, () -> deleteAnomalyDetectorJobDoc(detectorId, listener), listener); + } + }, transportService, true, listener); + }); + }, listener), + client, + clusterService, + xContentRegistry + ); + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + + private void deleteAnomalyDetectorJobDoc(String detectorId, ActionListener listener) { + LOG.info("Delete anomaly detector job {}", detectorId); + DeleteRequest deleteRequest = new DeleteRequest(CommonName.JOB_INDEX, detectorId) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.delete(deleteRequest, ActionListener.wrap(response -> { + if (response.getResult() == DocWriteResponse.Result.DELETED || response.getResult() == DocWriteResponse.Result.NOT_FOUND) { + deleteDetectorStateDoc(detectorId, listener); + } else { + String message = "Fail to delete anomaly detector job " + detectorId; + LOG.error(message); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + }, exception -> { + LOG.error("Failed to delete AD job for " + detectorId, exception); + if (exception instanceof IndexNotFoundException) { + deleteDetectorStateDoc(detectorId, listener); + } else { + LOG.error("Failed to delete anomaly detector job", exception); + listener.onFailure(exception); + } + })); + } + + private void deleteDetectorStateDoc(String detectorId, ActionListener listener) { + LOG.info("Delete detector info {}", detectorId); + DeleteRequest deleteRequest = new DeleteRequest(ADCommonName.DETECTION_STATE_INDEX, detectorId); + client + .delete( + deleteRequest, + ActionListener + .wrap( + response -> { + // whether deleted state doc or not, continue as state doc may not exist + deleteAnomalyDetectorDoc(detectorId, listener); + }, + exception -> { + if (exception instanceof IndexNotFoundException) { + deleteAnomalyDetectorDoc(detectorId, listener); 
+ } else { + LOG.error("Failed to delete detector state", exception); + listener.onFailure(exception); + } + } + ) + ); + } + + private void deleteAnomalyDetectorDoc(String detectorId, ActionListener listener) { + LOG.info("Delete anomaly detector {}", detectorId); + DeleteRequest deleteRequest = new DeleteRequest(CommonName.CONFIG_INDEX, detectorId) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.delete(deleteRequest, new ActionListener() { + @Override + public void onResponse(DeleteResponse deleteResponse) { + listener.onResponse(deleteResponse); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }); + } + + private void getDetectorJob(String detectorId, ActionListener listener, ExecutorFunction function) { + if (clusterService.state().metadata().indices().containsKey(CommonName.JOB_INDEX)) { + GetRequest request = new GetRequest(CommonName.JOB_INDEX).id(detectorId); + client.get(request, ActionListener.wrap(response -> onGetAdJobResponseForWrite(response, listener, function), exception -> { + LOG.error("Fail to get anomaly detector job: " + detectorId, exception); + listener.onFailure(exception); + })); + } else { + function.execute(); + } + } + + private void onGetAdJobResponseForWrite(GetResponse response, ActionListener listener, ExecutorFunction function) + throws IOException { + if (response.isExists()) { + String adJobId = response.getId(); + if (adJobId != null) { + // check if AD job is running on the detector, if yes, we can't delete the detector + try ( + XContentParser parser = RestHandlerUtils + .createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetectorJob adJob = AnomalyDetectorJob.parse(parser); + if (adJob.isEnabled()) { + listener.onFailure(new OpenSearchStatusException("Detector job is running: " + adJobId, RestStatus.BAD_REQUEST)); + return; + } + } catch (IOException e) { + String message = "Failed to parse anomaly detector job " + adJobId; + LOG.error(message, e); + } + } + } + function.execute(); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsAction.java-e b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsAction.java-e new file mode 100644 index 000000000..ae9de4c95 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsAction.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; +import org.opensearch.index.reindex.BulkByScrollResponse; + +public class DeleteAnomalyResultsAction extends ActionType { + // External Action which used for public facing RestAPIs. 
+ public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "results/delete"; + public static final DeleteAnomalyResultsAction INSTANCE = new DeleteAnomalyResultsAction(); + + private DeleteAnomalyResultsAction() { + super(NAME, BulkByScrollResponse::new); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportAction.java-e new file mode 100644 index 000000000..69b12ab0c --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportAction.java-e @@ -0,0 +1,89 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_DELETE_AD_RESULT; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.timeseries.util.ParseUtils.addUserBackendRolesFilter; +import static org.opensearch.timeseries.util.ParseUtils.getUserContext; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class DeleteAnomalyResultsTransportAction extends HandledTransportAction { + + private final Client client; + private volatile Boolean filterEnabled; + private static final Logger logger = LogManager.getLogger(DeleteAnomalyResultsTransportAction.class); + + @Inject + public DeleteAnomalyResultsTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Settings settings, + ClusterService clusterService, + Client client + ) { + super(DeleteAnomalyResultsAction.NAME, transportService, actionFilters, DeleteByQueryRequest::new); + this.client = client; + filterEnabled = FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterEnabled = it); + } + + @Override + protected void doExecute(Task task, DeleteByQueryRequest request, ActionListener actionListener) { + ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_DELETE_AD_RESULT); + delete(request, listener); + } + + public void delete(DeleteByQueryRequest request, ActionListener listener) { + User user = getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + validateRole(request, user, listener); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + 
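+ // Apply the backend-role filter to the delete-by-query only when security provides a user and filtering is enabled; otherwise execute the request as submitted.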
+ private void validateRole(DeleteByQueryRequest request, User user, ActionListener listener) { + if (user == null || !filterEnabled) { + // Case 1: user == null when 1. Security is disabled. 2. When user is super-admin + // Case 2: If Security is enabled and filter is disabled, proceed with search as + // user is already authenticated to hit this API. + client.execute(DeleteByQueryAction.INSTANCE, request, listener); + } else { + // Security is enabled and backend role filter is enabled + try { + addUserBackendRolesFilter(user, request.getSearchRequest().source()); + client.execute(DeleteByQueryAction.INSTANCE, request, listener); + } catch (Exception e) { + listener.onFailure(e); + } + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelAction.java-e b/src/main/java/org/opensearch/ad/transport/DeleteModelAction.java-e new file mode 100644 index 000000000..3af6982b0 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/DeleteModelAction.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class DeleteModelAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "model/delete"; + public static final DeleteModelAction INSTANCE = new DeleteModelAction(); + + private DeleteModelAction() { + super(NAME, DeleteModelResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeRequest.java b/src/main/java/org/opensearch/ad/transport/DeleteModelNodeRequest.java index a0cfabdd3..d10eef4c3 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeRequest.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteModelNodeRequest.java @@ -13,8 +13,8 @@ import java.io.IOException; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.transport.TransportRequest; /** diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeRequest.java-e b/src/main/java/org/opensearch/ad/transport/DeleteModelNodeRequest.java-e new file mode 100644 index 000000000..d10eef4c3 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/DeleteModelNodeRequest.java-e @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +/** + * Delete model represents the request to an individual node + */ +public class DeleteModelNodeRequest extends TransportRequest { + + private String adID; + + DeleteModelNodeRequest() {} + + DeleteModelNodeRequest(StreamInput in) throws IOException { + super(in); + this.adID = in.readString(); + } + + DeleteModelNodeRequest(DeleteModelRequest request) { + this.adID = request.getAdID(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(adID); + } + + public String getAdID() { + return adID; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeResponse.java b/src/main/java/org/opensearch/ad/transport/DeleteModelNodeResponse.java index b1fe2eaf9..c71e7368c 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeResponse.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteModelNodeResponse.java @@ -15,7 +15,7 @@ import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeResponse.java-e b/src/main/java/org/opensearch/ad/transport/DeleteModelNodeResponse.java-e new file mode 100644 index 000000000..c71e7368c --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/DeleteModelNodeResponse.java-e @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +public class DeleteModelNodeResponse extends BaseNodeResponse implements ToXContentObject { + static String NODE_ID = "node_id"; + + public DeleteModelNodeResponse(StreamInput in) throws IOException { + super(in); + } + + public DeleteModelNodeResponse(DiscoveryNode node) { + super(node); + } + + public static DeleteModelNodeResponse readNodeResponse(StreamInput in) throws IOException { + + return new DeleteModelNodeResponse(in); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(NODE_ID, getNode().getId()); + builder.endObject(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java b/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java index 2710fca8b..9ec58acda 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java @@ -20,9 +20,9 @@ import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java-e b/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java-e new file mode 100644 index 000000000..ed54649e2 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java-e @@ -0,0 +1,76 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +/** + * Request should be sent from the handler logic of transport delete detector API + * + */ +public class DeleteModelRequest extends BaseNodesRequest implements ToXContentObject { + private String adID; + + public String getAdID() { + return adID; + } + + public DeleteModelRequest() { + super((String[]) null); + } + + public DeleteModelRequest(StreamInput in) throws IOException { + super(in); + this.adID = in.readString(); + } + + public DeleteModelRequest(String adID, DiscoveryNode... nodes) { + super(nodes); + this.adID = adID; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(adID); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(adID)) { + validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + } + return validationException; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.endObject(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelResponse.java b/src/main/java/org/opensearch/ad/transport/DeleteModelResponse.java index 1b4b0731f..f2cbe2468 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelResponse.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteModelResponse.java @@ -17,8 +17,8 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelResponse.java-e b/src/main/java/org/opensearch/ad/transport/DeleteModelResponse.java-e new file mode 100644 index 000000000..f2cbe2468 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/DeleteModelResponse.java-e @@ -0,0 +1,55 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; + +public class DeleteModelResponse extends BaseNodesResponse implements ToXContentFragment { + static String NODES_JSON_KEY = "nodes"; + + public DeleteModelResponse(StreamInput in) throws IOException { + super(in); + } + + public DeleteModelResponse(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + } + + @Override + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(DeleteModelNodeResponse::readNodeResponse); + } + + @Override + protected void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startArray(NODES_JSON_KEY); + for (DeleteModelNodeResponse nodeResp : getNodes()) { + nodeResp.toXContent(builder, params); + } + builder.endArray(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java b/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java index 839e32666..b7a3bee88 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java @@ -28,7 +28,7 @@ import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java-e new file mode 100644 index 000000000..b7a3bee88 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java-e @@ -0,0 +1,137 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.List; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.ml.EntityColdStarter; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class DeleteModelTransportAction extends + TransportNodesAction { + private static final Logger LOG = LogManager.getLogger(DeleteModelTransportAction.class); + private NodeStateManager nodeStateManager; + private ModelManager modelManager; + private FeatureManager featureManager; + private CacheProvider cache; + private ADTaskCacheManager adTaskCacheManager; + private EntityColdStarter coldStarter; + + @Inject + public DeleteModelTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + NodeStateManager nodeStateManager, + ModelManager modelManager, + FeatureManager featureManager, + CacheProvider cache, + ADTaskCacheManager adTaskCacheManager, + EntityColdStarter coldStarter + ) { + super( + DeleteModelAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + DeleteModelRequest::new, + DeleteModelNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + DeleteModelNodeResponse.class + ); + this.nodeStateManager = nodeStateManager; + this.modelManager = modelManager; + this.featureManager = featureManager; + this.cache = cache; + this.adTaskCacheManager = adTaskCacheManager; + this.coldStarter = coldStarter; + } + + @Override + protected DeleteModelResponse newResponse( + DeleteModelRequest request, + List responses, + List failures + ) { + return new DeleteModelResponse(clusterService.getClusterName(), responses, failures); + } + + @Override + protected DeleteModelNodeRequest newNodeRequest(DeleteModelRequest request) { + return new DeleteModelNodeRequest(request); + } + + @Override + protected DeleteModelNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new DeleteModelNodeResponse(in); + } + + /** + * + * Delete checkpoint document (including both RCF and thresholding model), in-memory models, + * buffered shingle data, transport state, and anomaly result + * + * @param request delete request + * @return delete response including local node Id. 
+ */ + @Override + protected DeleteModelNodeResponse nodeOperation(DeleteModelNodeRequest request) { + + String adID = request.getAdID(); + LOG.info("Delete model for {}", adID); + // delete in-memory models and model checkpoint + modelManager + .clear( + adID, + ActionListener + .wrap( + r -> LOG.info("Deleted model for [{}] with response [{}] ", adID, r), + e -> LOG.error("Fail to delete model for " + adID, e) + ) + ); + + // delete buffered shingle data + featureManager.clear(adID); + + // delete transport state + nodeStateManager.clear(adID); + + cache.get().clear(adID); + + coldStarter.clear(adID); + + // delete realtime task cache + adTaskCacheManager.removeRealtimeTaskCache(adID); + + LOG.info("Finished deleting {}", adID); + return new DeleteModelNodeResponse(clusterService.localNode()); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileAction.java-e b/src/main/java/org/opensearch/ad/transport/EntityProfileAction.java-e new file mode 100644 index 000000000..c699d9a03 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/EntityProfileAction.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class EntityProfileAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile/entity"; + public static final EntityProfileAction INSTANCE = new EntityProfileAction(); + + private EntityProfileAction() { + super(NAME, EntityProfileResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java b/src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java index a388b33c5..7e4054a8a 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java +++ b/src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java @@ -22,9 +22,9 @@ import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.model.EntityProfileName; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.model.Entity; diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java-e b/src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java-e new file mode 100644 index 000000000..fc9274f5b --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/EntityProfileRequest.java-e @@ -0,0 +1,109 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.EntityProfileName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.model.Entity; + +public class EntityProfileRequest extends ActionRequest implements ToXContentObject { + public static final String ENTITY = "entity"; + public static final String PROFILES = "profiles"; + private String adID; + // changed from String to Entity since 1.1 + private Entity entityValue; + private Set profilesToCollect; + + public EntityProfileRequest(StreamInput in) throws IOException { + super(in); + adID = in.readString(); + entityValue = new Entity(in); + + int size = in.readVInt(); + profilesToCollect = new HashSet(); + if (size != 0) { + for (int i = 0; i < size; i++) { + profilesToCollect.add(in.readEnum(EntityProfileName.class)); + } + } + } + + public EntityProfileRequest(String adID, Entity entityValue, Set profilesToCollect) { + super(); + this.adID = adID; + this.entityValue = entityValue; + this.profilesToCollect = profilesToCollect; + } + + public String getAdID() { + return adID; + } + + public Entity getEntityValue() { + return entityValue; + } + + public Set getProfilesToCollect() { + return profilesToCollect; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(adID); + entityValue.writeTo(out); + + out.writeVInt(profilesToCollect.size()); + for (EntityProfileName profile : profilesToCollect) { + out.writeEnum(profile); + } + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(adID)) { + validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + } + if (entityValue == null) { + validationException = addValidationError("Entity value is missing", validationException); + } + if (profilesToCollect == null || profilesToCollect.isEmpty()) { + validationException = addValidationError(ADCommonMessages.EMPTY_PROFILES_COLLECT, validationException); + } + return validationException; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.field(ENTITY, entityValue); + builder.field(PROFILES, profilesToCollect); + builder.endObject(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java b/src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java index 36e7c6f6b..96546c3e3 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java +++ b/src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java @@ -20,8 +20,8 @@ import org.opensearch.action.ActionResponse; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.model.ModelProfileOnNode; -import 
org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java-e b/src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java-e new file mode 100644 index 000000000..96546c3e3 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/EntityProfileResponse.java-e @@ -0,0 +1,173 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.Optional; + +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.action.ActionResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.ModelProfileOnNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +public class EntityProfileResponse extends ActionResponse implements ToXContentObject { + public static final String ACTIVE = "active"; + public static final String LAST_ACTIVE_TS = "last_active_timestamp"; + public static final String TOTAL_UPDATES = "total_updates"; + private final Boolean isActive; + private final long lastActiveMs; + private final long totalUpdates; + private final ModelProfileOnNode modelProfile; + + public static class Builder { + private Boolean isActive = null; + private long lastActiveMs = -1L; + private long totalUpdates = -1L; + private ModelProfileOnNode modelProfile = null; + + public Builder() {} + + public Builder setActive(Boolean isActive) { + this.isActive = isActive; + return this; + } + + public Builder setLastActiveMs(long lastActiveMs) { + this.lastActiveMs = lastActiveMs; + return this; + } + + public Builder setTotalUpdates(long totalUpdates) { + this.totalUpdates = totalUpdates; + return this; + } + + public Builder setModelProfile(ModelProfileOnNode modelProfile) { + this.modelProfile = modelProfile; + return this; + } + + public EntityProfileResponse build() { + return new EntityProfileResponse(isActive, lastActiveMs, totalUpdates, modelProfile); + } + } + + public EntityProfileResponse(Boolean isActive, long lastActiveTimeMs, long totalUpdates, ModelProfileOnNode modelProfile) { + this.isActive = isActive; + this.lastActiveMs = lastActiveTimeMs; + this.totalUpdates = totalUpdates; + this.modelProfile = modelProfile; + } + + public EntityProfileResponse(StreamInput in) throws IOException { + super(in); + isActive = in.readOptionalBoolean(); + lastActiveMs = in.readLong(); + totalUpdates = in.readLong(); + if (in.readBoolean()) { + modelProfile = new ModelProfileOnNode(in); + } else { + modelProfile = null; + } + } + + public Optional isActive() { + return Optional.ofNullable(isActive); + } + + public long getLastActiveMs() { + return lastActiveMs; + } + + public long getTotalUpdates() { + 
return totalUpdates; + } + + public ModelProfileOnNode getModelProfile() { + return modelProfile; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalBoolean(isActive); + out.writeLong(lastActiveMs); + out.writeLong(totalUpdates); + if (modelProfile != null) { + out.writeBoolean(true); + modelProfile.writeTo(out); + } else { + out.writeBoolean(false); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (isActive != null) { + builder.field(ACTIVE, isActive); + } + if (lastActiveMs >= 0) { + builder.field(LAST_ACTIVE_TS, lastActiveMs); + } + if (totalUpdates >= 0) { + builder.field(TOTAL_UPDATES, totalUpdates); + } + if (modelProfile != null) { + builder.field(ADCommonName.MODEL, modelProfile); + } + builder.endObject(); + return builder; + } + + @Override + public String toString() { + ToStringBuilder builder = new ToStringBuilder(this); + builder.append(ACTIVE, isActive); + builder.append(LAST_ACTIVE_TS, lastActiveMs); + builder.append(TOTAL_UPDATES, totalUpdates); + builder.append(ADCommonName.MODEL, modelProfile); + + return builder.toString(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + if (obj instanceof EntityProfileResponse) { + EntityProfileResponse other = (EntityProfileResponse) obj; + EqualsBuilder equalsBuilder = new EqualsBuilder(); + equalsBuilder.append(isActive, other.isActive); + equalsBuilder.append(lastActiveMs, other.lastActiveMs); + equalsBuilder.append(totalUpdates, other.totalUpdates); + equalsBuilder.append(modelProfile, other.modelProfile); + + return equalsBuilder.isEquals(); + } + return false; + } + + @Override + public int hashCode() { + return new HashCodeBuilder().append(isActive).append(lastActiveMs).append(totalUpdates).append(modelProfile).toHashCode(); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java b/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java index b40976ec2..0a124360d 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java @@ -31,8 +31,8 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.common.exception.TimeSeriesException; diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java-e new file mode 100644 index 000000000..c9d56b3e6 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java-e @@ -0,0 +1,180 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.Locale; +import java.util.Optional; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.model.EntityProfileName; +import org.opensearch.ad.model.ModelProfile; +import org.opensearch.ad.model.ModelProfileOnNode; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.TransportResponseHandler; +import org.opensearch.transport.TransportService; + +/** + * Transport action to get entity profile. + */ +public class EntityProfileTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(EntityProfileTransportAction.class); + public static final String NO_NODE_FOUND_MSG = "Cannot find model hosting node"; + public static final String NO_MODEL_ID_FOUND_MSG = "Cannot find model id"; + static final String FAIL_TO_GET_ENTITY_PROFILE_MSG = "Cannot get entity profile info"; + + private final TransportService transportService; + private final HashRing hashRing; + private final TransportRequestOptions option; + private final ClusterService clusterService; + private final CacheProvider cacheProvider; + + @Inject + public EntityProfileTransportAction( + ActionFilters actionFilters, + TransportService transportService, + Settings settings, + HashRing hashRing, + ClusterService clusterService, + CacheProvider cacheProvider + ) { + super(EntityProfileAction.NAME, transportService, actionFilters, EntityProfileRequest::new); + this.transportService = transportService; + this.hashRing = hashRing; + this.option = TransportRequestOptions + .builder() + .withType(TransportRequestOptions.Type.REG) + .withTimeout(AnomalyDetectorSettings.REQUEST_TIMEOUT.get(settings)) + .build(); + this.clusterService = clusterService; + this.cacheProvider = cacheProvider; + } + + @Override + protected void doExecute(Task task, EntityProfileRequest request, ActionListener listener) { + + String adID = request.getAdID(); + Entity entityValue = request.getEntityValue(); + Optional modelIdOptional = entityValue.getModelId(adID); + if (false == modelIdOptional.isPresent()) { + listener.onFailure(new TimeSeriesException(adID, NO_MODEL_ID_FOUND_MSG)); + return; + } + // we use entity's toString (e.g., app_0) to find its node + // This should be consistent with how we land a model node in AnomalyResultTransportAction + Optional node = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(entityValue.toString()); + if (false == node.isPresent()) { + listener.onFailure(new TimeSeriesException(adID, NO_NODE_FOUND_MSG)); + 
return; + } + String nodeId = node.get().getId(); + String modelId = modelIdOptional.get(); + DiscoveryNode localNode = clusterService.localNode(); + if (localNode.getId().equals(nodeId)) { + EntityCache cache = cacheProvider.get(); + Set profilesToCollect = request.getProfilesToCollect(); + EntityProfileResponse.Builder builder = new EntityProfileResponse.Builder(); + if (profilesToCollect.contains(EntityProfileName.ENTITY_INFO)) { + builder.setActive(cache.isActive(adID, modelId)); + builder.setLastActiveMs(cache.getLastActiveMs(adID, modelId)); + } + if (profilesToCollect.contains(EntityProfileName.INIT_PROGRESS) || profilesToCollect.contains(EntityProfileName.STATE)) { + builder.setTotalUpdates(cache.getTotalUpdates(adID, modelId)); + } + if (profilesToCollect.contains(EntityProfileName.MODELS)) { + Optional modelProfile = cache.getModelProfile(adID, modelId); + if (modelProfile.isPresent()) { + builder.setModelProfile(new ModelProfileOnNode(nodeId, modelProfile.get())); + } + } + listener.onResponse(builder.build()); + } else if (request.remoteAddress() == null) { + // redirect if the request comes from the local host. + // If a request comes from a remote machine, it has already been redirected. + // One redirection should be enough. + // We don't want a potential infinite loop due to any bug and thus give up. + LOG.info("Sending entity profile request to {} for detector {}, entity {}", nodeId, adID, entityValue); + + try { + transportService + .sendRequest( + node.get(), + EntityProfileAction.NAME, + request, + option, + new TransportResponseHandler() { + + @Override + public EntityProfileResponse read(StreamInput in) throws IOException { + return new EntityProfileResponse(in); + } + + @Override + public void handleResponse(EntityProfileResponse response) { + listener.onResponse(response); + } + + @Override + public void handleException(TransportException exp) { + listener.onFailure(exp); + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + + } + ); + } catch (Exception e) { + LOG.error(String.format(Locale.ROOT, "Fail to get entity profile for detector %s, entity %s", adID, entityValue), e); + listener.onFailure(new TimeSeriesException(adID, FAIL_TO_GET_ENTITY_PROFILE_MSG, e)); + } + + } else { + // Prior to OpenSearch 1.1, we map a node using model id in the profile API. + // This is not consistent with how we map a node in AnomalyResultTransportAction, where + // we use entity values. We fixed that bug in OpenSearch 1.1, but the old mapping can still cause + // issues when a request routed by model id involves an old node. + // The new node finds the entity value does not map to itself, so it redirects to another node. + // The redirection can cause an infinite loop. This branch breaks the loop and gives up. + LOG + .error( + "Fail to get entity profile for detector {}, entity {}. Maybe because old and new nodes" + + " are using different keys for the hash ring.", + adID, + entityValue + ); + listener.onFailure(new TimeSeriesException(adID, FAIL_TO_GET_ENTITY_PROFILE_MSG)); + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/EntityResultAction.java-e b/src/main/java/org/opensearch/ad/transport/EntityResultAction.java-e new file mode 100644 index 000000000..c519858b4 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/EntityResultAction.java-e @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license.
+ * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.ad.constant.CommonValue; + +public class EntityResultAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "entity/result"; + public static final EntityResultAction INSTANCE = new EntityResultAction(); + + private EntityResultAction() { + super(NAME, AcknowledgedResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java b/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java index a2c6785ee..91041f447 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java +++ b/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java @@ -23,9 +23,9 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.constant.CommonMessages; diff --git a/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java-e b/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java-e new file mode 100644 index 000000000..98596288f --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java-e @@ -0,0 +1,125 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; +import java.util.Locale; +import java.util.Map; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; + +public class EntityResultRequest extends ActionRequest implements ToXContentObject { + private static final Logger LOG = LogManager.getLogger(EntityResultRequest.class); + private String detectorId; + // changed from Map to Map + private Map entities; + private long start; + private long end; + + public EntityResultRequest(StreamInput in) throws IOException { + super(in); + this.detectorId = in.readString(); + + // guarded with version check. Just in case we receive requests from older node where we use String + // to represent an entity + this.entities = in.readMap(Entity::new, StreamInput::readDoubleArray); + + this.start = in.readLong(); + this.end = in.readLong(); + } + + public EntityResultRequest(String detectorId, Map entities, long start, long end) { + super(); + this.detectorId = detectorId; + this.entities = entities; + this.start = start; + this.end = end; + } + + public String getId() { + return this.detectorId; + } + + public Map getEntities() { + return this.entities; + } + + public long getStart() { + return this.start; + } + + public long getEnd() { + return this.end; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.detectorId); + // guarded with version check. 
Just in case we send requests to older node where we use String + // to represent an entity + out.writeMap(entities, (s, e) -> e.writeTo(s), StreamOutput::writeDoubleArray); + + out.writeLong(this.start); + out.writeLong(this.end); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(detectorId)) { + validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + } + if (start <= 0 || end <= 0 || start > end) { + validationException = addValidationError( + String.format(Locale.ROOT, "%s: start %d, end %d", CommonMessages.INVALID_TIMESTAMP_ERR_MSG, start, end), + validationException + ); + } + return validationException; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ADCommonName.ID_JSON_KEY, detectorId); + builder.field(CommonName.START_JSON_KEY, start); + builder.field(CommonName.END_JSON_KEY, end); + builder.startArray(CommonName.ENTITIES_JSON_KEY); + for (final Map.Entry entry : entities.entrySet()) { + if (entry.getKey() != null) { + builder.startObject(); + builder.field(CommonName.ENTITY_KEY, entry.getKey()); + builder.field(CommonName.VALUE_JSON_KEY, entry.getValue()); + builder.endObject(); + } + } + builder.endArray(); + builder.endObject(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java index 0927a88b1..fd48b302b 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java @@ -27,7 +27,6 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.caching.CacheProvider; @@ -52,6 +51,7 @@ import org.opensearch.common.inject.Inject; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.constant.CommonMessages; @@ -127,7 +127,9 @@ public EntityResultTransportAction( @Override protected void doExecute(Task task, EntityResultRequest request, ActionListener listener) { if (adCircuitBreakerService.isOpen()) { - threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME).execute(() -> cache.get().releaseMemoryForOpenCircuitBreaker()); + threadPool + .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) + .execute(() -> cache.get().releaseMemoryForOpenCircuitBreaker()); listener.onFailure(new LimitExceededException(request.getId(), CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); return; } diff --git a/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java-e new file mode 100644 index 000000000..fd48b302b --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java-e @@ -0,0 +1,354 @@ +/* + * 
SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.ratelimit.CheckpointReadWorker; +import org.opensearch.ad.ratelimit.ColdEntityWorker; +import org.opensearch.ad.ratelimit.EntityColdStartWorker; +import org.opensearch.ad.ratelimit.EntityFeatureRequest; +import org.opensearch.ad.ratelimit.RequestPriority; +import org.opensearch.ad.ratelimit.ResultWriteRequest; +import org.opensearch.ad.ratelimit.ResultWriteWorker; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.common.inject.Inject; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.transport.TransportService; + +/** + * Entry-point for HCAD workflow. We have created multiple queues for coordinating + * the workflow. The overrall workflow is: + * 1. We store as many frequently used entity models in a cache as allowed by the + * memory limit (10% heap). If an entity feature is a hit, we use the in-memory model + * to detect anomalies and record results using the result write queue. + * 2. If an entity feature is a miss, we check if there is free memory or any other + * entity's model can be evacuated. An in-memory entity's frequency may be lower + * compared to the cache miss entity. If that's the case, we replace the lower + * frequency entity's model with the higher frequency entity's model. To load the + * higher frequency entity's model, we first check if a model exists on disk by + * sending a checkpoint read queue request. If there is a checkpoint, we load it + * to memory, perform detection, and save the result using the result write queue. 
+ * Otherwise, we enqueue a cold start request to the cold start queue for model + * training. If training is successful, we save the learned model via the checkpoint + * write queue. + * 3. We also have the cold entity queue configured for cold entities, and the model + * training and inference are connected by serial juxtaposition to limit resource usage. + */ +public class EntityResultTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(EntityResultTransportAction.class); + private ModelManager modelManager; + private ADCircuitBreakerService adCircuitBreakerService; + private CacheProvider cache; + private final NodeStateManager stateManager; + private ADIndexManagement indexUtil; + private ResultWriteWorker resultWriteQueue; + private CheckpointReadWorker checkpointReadQueue; + private ColdEntityWorker coldEntityQueue; + private ThreadPool threadPool; + private EntityColdStartWorker entityColdStartWorker; + private ADStats adStats; + + @Inject + public EntityResultTransportAction( + ActionFilters actionFilters, + TransportService transportService, + ModelManager manager, + ADCircuitBreakerService adCircuitBreakerService, + CacheProvider entityCache, + NodeStateManager stateManager, + ADIndexManagement indexUtil, + ResultWriteWorker resultWriteQueue, + CheckpointReadWorker checkpointReadQueue, + ColdEntityWorker coldEntityQueue, + ThreadPool threadPool, + EntityColdStartWorker entityColdStartWorker, + ADStats adStats + ) { + super(EntityResultAction.NAME, transportService, actionFilters, EntityResultRequest::new); + this.modelManager = manager; + this.adCircuitBreakerService = adCircuitBreakerService; + this.cache = entityCache; + this.stateManager = stateManager; + this.indexUtil = indexUtil; + this.resultWriteQueue = resultWriteQueue; + this.checkpointReadQueue = checkpointReadQueue; + this.coldEntityQueue = coldEntityQueue; + this.threadPool = threadPool; + this.entityColdStartWorker = entityColdStartWorker; + this.adStats = adStats; + } + + @Override + protected void doExecute(Task task, EntityResultRequest request, ActionListener listener) { + if (adCircuitBreakerService.isOpen()) { + threadPool + .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) + .execute(() -> cache.get().releaseMemoryForOpenCircuitBreaker()); + listener.onFailure(new LimitExceededException(request.getId(), CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); + return; + } + + try { + String detectorId = request.getId(); + + Optional previousException = stateManager.fetchExceptionAndClear(detectorId); + + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error("Previous exception of {}: {}", detectorId, exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + + listener = ExceptionUtil.wrapListener(listener, exception, detectorId); + } + + stateManager.getAnomalyDetector(detectorId, onGetDetector(listener, detectorId, request, previousException)); + } catch (Exception exception) { + LOG.error("fail to get entity's anomaly grade", exception); + listener.onFailure(exception); + } + } + + private ActionListener> onGetDetector( + ActionListener listener, + String detectorId, + EntityResultRequest request, + Optional prevException + ) { + return ActionListener.wrap(detectorOptional -> { + if (!detectorOptional.isPresent()) { + listener.onFailure(new 
EndRunException(detectorId, "AnomalyDetector is not available.", false)); + return; + } + + AnomalyDetector detector = detectorOptional.get(); + + if (request.getEntities() == null) { + listener.onFailure(new EndRunException(detectorId, "Fail to get any entities from request.", false)); + return; + } + + Instant executionStartTime = Instant.now(); + Map cacheMissEntities = new HashMap<>(); + for (Entry entityEntry : request.getEntities().entrySet()) { + Entity categoricalValues = entityEntry.getKey(); + + if (isEntityFromOldNodeMsg(categoricalValues) + && detector.getCategoryFields() != null + && detector.getCategoryFields().size() == 1) { + Map attrValues = categoricalValues.getAttributes(); + // handle a request from a version before OpenSearch 1.1. + categoricalValues = Entity + .createSingleAttributeEntity(detector.getCategoryFields().get(0), attrValues.get(ADCommonName.EMPTY_FIELD)); + } + + Optional modelIdOptional = categoricalValues.getModelId(detectorId); + if (false == modelIdOptional.isPresent()) { + continue; + } + + String modelId = modelIdOptional.get(); + double[] datapoint = entityEntry.getValue(); + ModelState entityModel = cache.get().get(modelId, detector); + if (entityModel == null) { + // cache miss + cacheMissEntities.put(categoricalValues, datapoint); + continue; + } + try { + ThresholdingResult result = modelManager + .getAnomalyResultForEntity(datapoint, entityModel, modelId, categoricalValues, detector.getShingleSize()); + // result.getRcfScore() = 0 means the model is not initialized + // result.getGrade() = 0 means it is not an anomaly + // So many OpenSearchRejectedExecutionException if we write no matter what + if (result.getRcfScore() > 0) { + List resultsToSave = result + .toIndexableResults( + detector, + Instant.ofEpochMilli(request.getStart()), + Instant.ofEpochMilli(request.getEnd()), + executionStartTime, + Instant.now(), + ParseUtils.getFeatureData(datapoint, detector), + Optional.ofNullable(categoricalValues), + indexUtil.getSchemaVersion(ADIndex.RESULT), + modelId, + null, + null + ); + for (AnomalyResult r : resultsToSave) { + resultWriteQueue + .put( + new ResultWriteRequest( + System.currentTimeMillis() + detector.getIntervalInMilliseconds(), + detectorId, + result.getGrade() > 0 ? RequestPriority.HIGH : RequestPriority.MEDIUM, + r, + detector.getCustomResultIndex() + ) + ); + } + } + } catch (IllegalArgumentException e) { + // fail to score likely due to model corruption. Re-cold start to recover. 
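// Recovery steps: bump the model-corruption stat, evict the cached model for this entity,
// and enqueue the entity on the cold-start worker so its model is retrained.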
+ LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", modelId), e); + adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).increment(); + cache.get().removeEntityModel(detectorId, modelId); + entityColdStartWorker + .put( + new EntityFeatureRequest( + System.currentTimeMillis() + detector.getIntervalInMilliseconds(), + detectorId, + RequestPriority.MEDIUM, + categoricalValues, + datapoint, + request.getStart() + ) + ); + } + } + + // split hot and cold entities + Pair, List> hotColdEntities = cache + .get() + .selectUpdateCandidate(cacheMissEntities.keySet(), detectorId, detector); + + List hotEntityRequests = new ArrayList<>(); + List coldEntityRequests = new ArrayList<>(); + + for (Entity hotEntity : hotColdEntities.getLeft()) { + double[] hotEntityValue = cacheMissEntities.get(hotEntity); + if (hotEntityValue == null) { + LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", hotEntity)); + continue; + } + hotEntityRequests + .add( + new EntityFeatureRequest( + System.currentTimeMillis() + detector.getIntervalInMilliseconds(), + detectorId, + // hot entities has MEDIUM priority + RequestPriority.MEDIUM, + hotEntity, + hotEntityValue, + request.getStart() + ) + ); + } + + for (Entity coldEntity : hotColdEntities.getRight()) { + double[] coldEntityValue = cacheMissEntities.get(coldEntity); + if (coldEntityValue == null) { + LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", coldEntity)); + continue; + } + coldEntityRequests + .add( + new EntityFeatureRequest( + System.currentTimeMillis() + detector.getIntervalInMilliseconds(), + detectorId, + // cold entities has LOW priority + RequestPriority.LOW, + coldEntity, + coldEntityValue, + request.getStart() + ) + ); + } + + checkpointReadQueue.putAll(hotEntityRequests); + coldEntityQueue.putAll(coldEntityRequests); + + // respond back + if (prevException.isPresent()) { + listener.onFailure(prevException.get()); + } else { + listener.onResponse(new AcknowledgedResponse(true)); + } + }, exception -> { + LOG + .error( + new ParameterizedMessage( + "fail to get entity's anomaly grade for detector [{}]: start: [{}], end: [{}]", + detectorId, + request.getStart(), + request.getEnd() + ), + exception + ); + listener.onFailure(exception); + }); + } + + /** + * Whether the received entity comes from an node that doesn't support multi-category fields. + * This can happen during rolling-upgrade or blue/green deployment. + * + * Specifically, when receiving an EntityResultRequest from an incompatible node, + * EntityResultRequest(StreamInput in) gets an String that represents an entity. + * But Entity class requires both an category field name and value. Since we + * don't have access to detector config in EntityResultRequest(StreamInput in), + * we put CommonName.EMPTY_FIELD as the placeholder. In this method, + * we use the same CommonName.EMPTY_FIELD to check if the deserialized entity + * comes from an incompatible node. If it is, we will add the field name back + * as EntityResultTranportAction has access to the detector config object. + * + * @param categoricalValues deserialized Entity from inbound message. + * @return Whether the received entity comes from an node that doesn't support multi-category fields. 
+ */ + private boolean isEntityFromOldNodeMsg(Entity categoricalValues) { + Map attrValues = categoricalValues.getAttributes(); + return (attrValues != null && attrValues.containsKey(ADCommonName.EMPTY_FIELD)); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ForwardADTaskAction.java-e b/src/main/java/org/opensearch/ad/transport/ForwardADTaskAction.java-e new file mode 100644 index 000000000..309714cc8 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ForwardADTaskAction.java-e @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonName.AD_TASK; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class ForwardADTaskAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK + "/forward"; + public static final ForwardADTaskAction INSTANCE = new ForwardADTaskAction(); + + private ForwardADTaskAction() { + super(NAME, AnomalyDetectorJobResponse::new); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ForwardADTaskRequest.java b/src/main/java/org/opensearch/ad/transport/ForwardADTaskRequest.java index 7e2c9ea70..417696609 100644 --- a/src/main/java/org/opensearch/ad/transport/ForwardADTaskRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ForwardADTaskRequest.java @@ -24,9 +24,9 @@ import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskAction; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.timeseries.common.exception.VersionException; import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.model.DateRange; diff --git a/src/main/java/org/opensearch/ad/transport/ForwardADTaskRequest.java-e b/src/main/java/org/opensearch/ad/transport/ForwardADTaskRequest.java-e new file mode 100644 index 000000000..1ad231ec5 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ForwardADTaskRequest.java-e @@ -0,0 +1,218 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import org.opensearch.Version; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskAction; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.commons.authuser.User; +import org.opensearch.timeseries.common.exception.VersionException; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.transport.TransportService; + +public class ForwardADTaskRequest extends ActionRequest { + private AnomalyDetector detector; + private ADTask adTask; + private DateRange detectionDateRange; + private List staleRunningEntities; + private User user; + private Integer availableTaskSlots; + private ADTaskAction adTaskAction; + + /** + * Constructor function. + * For most task actions, we only send ForwardADTaskRequest to node with same local AD version. + * But it's possible that we need to clean up detector cache by sending FINISHED task action to + * an old coordinating node when no task running for the detector. + * Check {@link org.opensearch.ad.task.ADTaskManager#cleanDetectorCache(ADTask, TransportService, ExecutorFunction)}. + * + * @param detector detector + * @param detectionDateRange detection date range + * @param user user + * @param adTaskAction AD task action + * @param availableTaskSlots available task slots + * @param remoteAdVersion AD version of remote node + */ + public ForwardADTaskRequest( + AnomalyDetector detector, + DateRange detectionDateRange, + User user, + ADTaskAction adTaskAction, + Integer availableTaskSlots, + Version remoteAdVersion + ) { + if (remoteAdVersion == null) { + throw new VersionException(detector.getId(), "Can't forward AD task request to node running null AD version "); + } + this.detector = detector; + this.detectionDateRange = detectionDateRange; + this.user = user; + this.availableTaskSlots = availableTaskSlots; + this.adTaskAction = adTaskAction; + } + + public ForwardADTaskRequest(AnomalyDetector detector, DateRange detectionDateRange, User user, ADTaskAction adTaskAction) { + this.detector = detector; + this.detectionDateRange = detectionDateRange; + this.user = user; + this.adTaskAction = adTaskAction; + } + + public ForwardADTaskRequest(ADTask adTask, ADTaskAction adTaskAction) { + this(adTask, adTaskAction, null); + } + + public ForwardADTaskRequest(ADTask adTask, Integer availableTaskSLots, ADTaskAction adTaskAction) { + this(adTask, adTaskAction, null); + this.availableTaskSlots = availableTaskSLots; + } + + public ForwardADTaskRequest(ADTask adTask, ADTaskAction adTaskAction, List staleRunningEntities) { + this.adTask = adTask; + this.adTaskAction = adTaskAction; + if (adTask != null) { + this.detector = adTask.getDetector(); + } + this.staleRunningEntities = staleRunningEntities; + } + + public ForwardADTaskRequest(StreamInput in) throws IOException { + super(in); + this.detector = new AnomalyDetector(in); + if (in.readBoolean()) { + this.user = new User(in); + } + this.adTaskAction = in.readEnum(ADTaskAction.class); + if (in.available() == 0) { + // 
Old version on or before 1.0 will send less fields. + // This will reject request from old node running AD version on or before 1.0. + // So if coordinating node is old node, it can't use new node as worker node + // to run task. + throw new VersionException("Can't process ForwardADTaskRequest of old version"); + } + if (in.readBoolean()) { + this.adTask = new ADTask(in); + } + if (in.readBoolean()) { + this.detectionDateRange = new DateRange(in); + } + this.staleRunningEntities = in.readOptionalStringList(); + availableTaskSlots = in.readOptionalInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + detector.writeTo(out); + if (user != null) { + out.writeBoolean(true); + user.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeEnum(adTaskAction); + // From AD 1.1, only forward AD task request to nodes with same local AD version + if (adTask != null) { + out.writeBoolean(true); + adTask.writeTo(out); + } else { + out.writeBoolean(false); + } + if (detectionDateRange != null) { + out.writeBoolean(true); + detectionDateRange.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalStringCollection(staleRunningEntities); + out.writeOptionalInt(availableTaskSlots); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (detector == null) { + validationException = addValidationError(ADCommonMessages.DETECTOR_MISSING, validationException); + } else if (detector.getId() == null) { + validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + } + if (adTaskAction == null) { + validationException = addValidationError(ADCommonMessages.AD_TASK_ACTION_MISSING, validationException); + } + if (adTaskAction == ADTaskAction.CLEAN_STALE_RUNNING_ENTITIES && (staleRunningEntities == null || staleRunningEntities.isEmpty())) { + validationException = addValidationError(ADCommonMessages.EMPTY_STALE_RUNNING_ENTITIES, validationException); + } + return validationException; + } + + public AnomalyDetector getDetector() { + return detector; + } + + public ADTask getAdTask() { + return adTask; + } + + public DateRange getDetectionDateRange() { + return detectionDateRange; + } + + public User getUser() { + return user; + } + + public ADTaskAction getAdTaskAction() { + return adTaskAction; + } + + public List getStaleRunningEntities() { + return staleRunningEntities; + } + + public Integer getAvailableTaskSLots() { + return availableTaskSlots; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + ForwardADTaskRequest request = (ForwardADTaskRequest) o; + return Objects.equals(detector, request.detector) + && Objects.equals(adTask, request.adTask) + && Objects.equals(detectionDateRange, request.detectionDateRange) + && Objects.equals(staleRunningEntities, request.staleRunningEntities) + && Objects.equals(user, request.user) + && Objects.equals(availableTaskSlots, request.availableTaskSlots) + && adTaskAction == request.adTaskAction; + } + + @Override + public int hashCode() { + return Objects.hash(detector, adTask, detectionDateRange, staleRunningEntities, user, availableTaskSlots, adTaskAction); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java b/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java index be3b45c99..adc8e36a8 100644 --- 
a/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java @@ -34,7 +34,7 @@ import org.opensearch.ad.task.ADTaskManager; import org.opensearch.common.inject.Inject; import org.opensearch.commons.authuser.User; -import org.opensearch.rest.RestStatus; +import org.opensearch.core.rest.RestStatus; import org.opensearch.tasks.Task; import org.opensearch.timeseries.model.DateRange; import org.opensearch.transport.TransportService; diff --git a/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java-e new file mode 100644 index 000000000..adc8e36a8 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java-e @@ -0,0 +1,260 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.model.ADTask.ERROR_FIELD; +import static org.opensearch.ad.model.ADTask.STATE_FIELD; +import static org.opensearch.ad.model.ADTask.TASK_PROGRESS_FIELD; + +import java.util.Arrays; +import java.util.List; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskAction; +import org.opensearch.ad.model.ADTaskState; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.common.inject.Inject; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableMap; + +public class ForwardADTaskTransportAction extends HandledTransportAction { + private final Logger logger = LogManager.getLogger(ForwardADTaskTransportAction.class); + private final TransportService transportService; + private final ADTaskManager adTaskManager; + private final ADTaskCacheManager adTaskCacheManager; + + // ========================================================= + // Fields below contains cache for realtime AD on coordinating + // node. We need to clean up these caches when receive FINISHED + // action for realtime task. + // ========================================================= + // NodeStateManager caches anomaly detector's backpressure counter for realtime detection. + private final NodeStateManager stateManager; + // FeatureManager caches anomaly detector's feature data points for shingling of realtime detection. 
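// Both caches are keyed by detector id; the CLEAN_CACHE branch below clears them when the
// realtime job's coordinating node changes (e.g. after the hash ring changes).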
+ private final FeatureManager featureManager; + + @Inject + public ForwardADTaskTransportAction( + ActionFilters actionFilters, + TransportService transportService, + ADTaskManager adTaskManager, + ADTaskCacheManager adTaskCacheManager, + FeatureManager featureManager, + NodeStateManager stateManager + ) { + super(ForwardADTaskAction.NAME, transportService, actionFilters, ForwardADTaskRequest::new); + this.adTaskManager = adTaskManager; + this.transportService = transportService; + this.adTaskCacheManager = adTaskCacheManager; + this.featureManager = featureManager; + this.stateManager = stateManager; + } + + @Override + protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener listener) { + ADTaskAction adTaskAction = request.getAdTaskAction(); + AnomalyDetector detector = request.getDetector(); + DateRange detectionDateRange = request.getDetectionDateRange(); + String detectorId = detector.getId(); + ADTask adTask = request.getAdTask(); + User user = request.getUser(); + Integer availableTaskSlots = request.getAvailableTaskSLots(); + + String entityValue = adTaskManager.convertEntityToString(adTask); + + switch (adTaskAction) { + case APPLY_FOR_TASK_SLOTS: + logger.debug("Received APPLY_FOR_TASK_SLOTS action for detector {}", detectorId); + adTaskManager.checkTaskSlots(adTask, detector, detectionDateRange, user, ADTaskAction.START, transportService, listener); + break; + case CHECK_AVAILABLE_TASK_SLOTS: + logger.debug("Received CHECK_AVAILABLE_TASK_SLOTS action for detector {}", detectorId); + adTaskManager + .checkTaskSlots( + adTask, + detector, + detectionDateRange, + user, + ADTaskAction.SCALE_ENTITY_TASK_SLOTS, + transportService, + listener + ); + break; + case START: + // Start historical analysis for detector + logger.debug("Received START action for detector {}", detectorId); + adTaskManager.startDetector(detector, detectionDateRange, user, transportService, ActionListener.wrap(r -> { + adTaskCacheManager.setDetectorTaskSlots(detector.getId(), availableTaskSlots); + listener.onResponse(r); + }, e -> listener.onFailure(e))); + break; + case NEXT_ENTITY: + logger.debug("Received NEXT_ENTITY action for detector {}, task {}", detectorId, adTask.getTaskId()); + // Run next entity for HC detector historical analysis. + if (detector.isHighCardinality()) { // AD task could be HC detector level task or entity task + adTaskCacheManager.removeRunningEntity(detectorId, entityValue); + if (!adTaskCacheManager.hasEntity(detectorId)) { + adTaskCacheManager.setDetectorTaskSlots(detectorId, 0); + logger.info("Historical HC detector done, will remove from cache, detector id:{}", detectorId); + listener.onResponse(new AnomalyDetectorJobResponse(detectorId, 0, 0, 0, RestStatus.OK)); + ADTaskState state = !adTask.isEntityTask() && adTask.getError() != null ? ADTaskState.FAILED : ADTaskState.FINISHED; + adTaskManager.setHCDetectorTaskDone(adTask, state, listener); + } else { + logger.debug("Run next entity for detector " + detectorId); + adTaskManager.runNextEntityForHCADHistorical(adTask, transportService, listener); + adTaskManager + .updateADHCDetectorTask( + detectorId, + adTask.getParentTaskId(), + ImmutableMap + .of( + STATE_FIELD, + ADTaskState.RUNNING.name(), + TASK_PROGRESS_FIELD, + adTaskManager.hcDetectorProgress(detectorId), + ERROR_FIELD, + adTask.getError() != null ? 
adTask.getError() : "" + ) + ); + } + } else { + logger + .warn( + "Can only handle HC entity task for NEXT_ENTITY action, taskId:{} , taskType:{}", + adTask.getTaskId(), + adTask.getTaskType() + ); + listener.onFailure(new IllegalArgumentException("Unsupported task")); + } + break; + case PUSH_BACK_ENTITY: + logger.debug("Received PUSH_BACK_ENTITY action for detector {}, task {}", detectorId, adTask.getTaskId()); + // Push back entity to pending entities queue and run next entity. + if (adTask.isEntityTask()) { // AD task must be entity level task. + adTaskCacheManager.removeRunningEntity(detectorId, entityValue); + if (adTaskManager.isRetryableError(adTask.getError()) + && !adTaskCacheManager.exceedRetryLimit(adTask.getId(), adTask.getTaskId())) { + // If retryable exception happens when run entity task, will push back entity to the end + // of pending entities queue, then we can retry it later. + adTaskCacheManager.pushBackEntity(adTask.getTaskId(), adTask.getId(), entityValue); + } else { + // If exception is not retryable or exceeds retry limit, will remove this entity. + adTaskCacheManager.removeEntity(adTask.getId(), entityValue); + logger.warn("Entity task failed, task id: {}, entity: {}", adTask.getTaskId(), adTask.getEntity().toString()); + } + if (!adTaskCacheManager.hasEntity(detectorId)) { + adTaskCacheManager.setDetectorTaskSlots(detectorId, 0); + adTaskManager.setHCDetectorTaskDone(adTask, ADTaskState.FINISHED, listener); + } else { + logger.debug("scale task slots for PUSH_BACK_ENTITY, detector {} task {}", detectorId, adTask.getTaskId()); + int taskSlots = adTaskCacheManager.scaleDownHCDetectorTaskSlots(detectorId, 1); + if (taskSlots == 1) { + logger.debug("After scale down, only 1 task slot reserved for detector {}, run next entity", detectorId); + adTaskManager.runNextEntityForHCADHistorical(adTask, transportService, listener); + } + listener.onResponse(new AnomalyDetectorJobResponse(adTask.getTaskId(), 0, 0, 0, RestStatus.ACCEPTED)); + } + } else { + logger.warn("Can only push back entity task"); + listener.onFailure(new IllegalArgumentException("Can only push back entity task")); + } + break; + case SCALE_ENTITY_TASK_SLOTS: + logger.debug("Received SCALE_ENTITY_TASK_LANE action for detector {}", detectorId); + // Check current available task slots and scale entity task lane. + if (availableTaskSlots != null && availableTaskSlots > 0) { + int newSlots = Math.min(availableTaskSlots, adTaskManager.detectorTaskSlotScaleDelta(detectorId)); + if (newSlots > 0) { + adTaskCacheManager.setAllowedRunningEntities(detectorId, newSlots); + adTaskCacheManager.scaleUpDetectorTaskSlots(detectorId, newSlots); + } + } + listener.onResponse(new AnomalyDetectorJobResponse(detector.getId(), 0, 0, 0, RestStatus.OK)); + break; + case CANCEL: + logger.debug("Received CANCEL action for detector {}", detectorId); + // Cancel HC detector's historical analysis. + // Don't support single detector for this action as single entity task will be cancelled directly + // on worker node. 
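// For an HC detector we drop its pending entities, remove the reporting entity from the
// running set, and mark the detector-level task STOPPED when no running entity remains
// (or when the forwarded task is the detector-level task itself).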
+ if (detector.isHighCardinality()) { + adTaskCacheManager.clearPendingEntities(detectorId); + adTaskCacheManager.removeRunningEntity(detectorId, entityValue); + if (!adTaskCacheManager.hasEntity(detectorId) || !adTask.isEntityTask()) { + adTaskManager.setHCDetectorTaskDone(adTask, ADTaskState.STOPPED, listener); + } + listener.onResponse(new AnomalyDetectorJobResponse(adTask.getTaskId(), 0, 0, 0, RestStatus.OK)); + } else { + listener.onFailure(new IllegalArgumentException("Only support cancel HC now")); + } + break; + case CLEAN_STALE_RUNNING_ENTITIES: + logger.debug("Received CLEAN_STALE_RUNNING_ENTITIES action for detector {}", detectorId); + // Clean stale running entities of HC detector. For example, some worker node crashed or failed to send + // entity task done message to coordinating node, then coordinating node can't remove running entity + // from cache. We will check task profile when get task. If some entities exist in coordinating cache but + // doesn't exist in worker node's cache, we will clean up these stale running entities on coordinating node. + List staleRunningEntities = request.getStaleRunningEntities(); + logger + .debug( + "Clean stale running entities of task {}, staleRunningEntities: {}", + adTask.getTaskId(), + Arrays.toString(staleRunningEntities.toArray(new String[0])) + ); + for (String entity : staleRunningEntities) { + adTaskManager.removeStaleRunningEntity(adTask, entity, transportService, listener); + } + listener.onResponse(new AnomalyDetectorJobResponse(adTask.getTaskId(), 0, 0, 0, RestStatus.OK)); + break; + case CLEAN_CACHE: + boolean historicalTask = adTask.isHistoricalTask(); + logger + .debug( + "Received CLEAN_CACHE action for detector {}, taskId: {}, historical: {}", + detectorId, + adTask.getTaskId(), + historicalTask + ); + if (historicalTask) { + // Don't clear task cache if still has running entity. CLEAN_STALE_RUNNING_ENTITIES will clean + // stale running entity. + adTaskCacheManager.removeHistoricalTaskCacheIfNoRunningEntity(detectorId); + } else { + adTaskCacheManager.removeRealtimeTaskCache(detectorId); + // If hash ring changed like new node added when scale out, the realtime job coordinating node may + // change, then we should clean up cache on old coordinating node. + stateManager.clear(detectorId); + featureManager.clear(detectorId); + } + listener.onResponse(new AnomalyDetectorJobResponse(detector.getId(), 0, 0, 0, RestStatus.OK)); + break; + default: + listener.onFailure(new OpenSearchStatusException("Unsupported AD task action " + adTaskAction, RestStatus.BAD_REQUEST)); + break; + } + + } +} diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorAction.java-e b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorAction.java-e new file mode 100644 index 000000000..c4232047d --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorAction.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class GetAnomalyDetectorAction extends ActionType { + // External Action which used for public facing RestAPIs. 
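// Registered under EXTERNAL_ACTION_PREFIX, unlike node-to-node actions such as
// ForwardADTaskAction above, which use INTERNAL_ACTION_PREFIX.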
+ public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detectors/get"; + public static final GetAnomalyDetectorAction INSTANCE = new GetAnomalyDetectorAction(); + + private GetAnomalyDetectorAction() { + super(NAME, GetAnomalyDetectorResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java index a0ed18941..aef29626d 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java @@ -15,8 +15,8 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.timeseries.model.Entity; public class GetAnomalyDetectorRequest extends ActionRequest { diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java-e b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java-e new file mode 100644 index 000000000..aef29626d --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java-e @@ -0,0 +1,122 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.timeseries.model.Entity; + +public class GetAnomalyDetectorRequest extends ActionRequest { + + private String detectorID; + private long version; + private boolean returnJob; + private boolean returnTask; + private String typeStr; + private String rawPath; + private boolean all; + private Entity entity; + + public GetAnomalyDetectorRequest(StreamInput in) throws IOException { + super(in); + detectorID = in.readString(); + version = in.readLong(); + returnJob = in.readBoolean(); + returnTask = in.readBoolean(); + typeStr = in.readString(); + rawPath = in.readString(); + all = in.readBoolean(); + if (in.readBoolean()) { + entity = new Entity(in); + } + } + + public GetAnomalyDetectorRequest( + String detectorID, + long version, + boolean returnJob, + boolean returnTask, + String typeStr, + String rawPath, + boolean all, + Entity entity + ) { + super(); + this.detectorID = detectorID; + this.version = version; + this.returnJob = returnJob; + this.returnTask = returnTask; + this.typeStr = typeStr; + this.rawPath = rawPath; + this.all = all; + this.entity = entity; + } + + public String getDetectorID() { + return detectorID; + } + + public long getVersion() { + return version; + } + + public boolean isReturnJob() { + return returnJob; + } + + public boolean isReturnTask() { + return returnTask; + } + + public String getTypeStr() { + return typeStr; + } + + public String getRawPath() { + return rawPath; + } + + public boolean isAll() { + return all; + } + + public Entity getEntity() { + return 
entity; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(detectorID); + out.writeLong(version); + out.writeBoolean(returnJob); + out.writeBoolean(returnTask); + out.writeString(typeStr); + out.writeString(rawPath); + out.writeBoolean(all); + if (this.entity != null) { + out.writeBoolean(true); + entity.writeTo(out); + } else { + out.writeBoolean(false); + } + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java index a2cd47e6a..e1532b816 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java @@ -19,11 +19,11 @@ import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.DetectorProfile; import org.opensearch.ad.model.EntityProfile; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.rest.RestStatus; import org.opensearch.timeseries.util.RestHandlerUtils; public class GetAnomalyDetectorResponse extends ActionResponse implements ToXContentObject { diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java-e b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java-e new file mode 100644 index 000000000..a1a2b96cc --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java-e @@ -0,0 +1,215 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionResponse; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.model.DetectorProfile; +import org.opensearch.ad.model.EntityProfile; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.timeseries.util.RestHandlerUtils; + +public class GetAnomalyDetectorResponse extends ActionResponse implements ToXContentObject { + public static final String DETECTOR_PROFILE = "detectorProfile"; + public static final String ENTITY_PROFILE = "entityProfile"; + private long version; + private String id; + private long primaryTerm; + private long seqNo; + private AnomalyDetector detector; + private AnomalyDetectorJob adJob; + private ADTask realtimeAdTask; + private ADTask historicalAdTask; + private RestStatus restStatus; + private DetectorProfile detectorProfile; + private EntityProfile entityProfile; + private boolean profileResponse; + private boolean returnJob; + private boolean returnTask; + + public GetAnomalyDetectorResponse(StreamInput in) throws IOException { + super(in); + profileResponse = in.readBoolean(); + if (profileResponse) { + String profileType = in.readString(); + if (DETECTOR_PROFILE.equals(profileType)) { + detectorProfile = new DetectorProfile(in); + } else { + entityProfile = new EntityProfile(in); + } + + } else { + detectorProfile = null; + id = in.readString(); + version = in.readLong(); + primaryTerm = in.readLong(); + seqNo = in.readLong(); + restStatus = in.readEnum(RestStatus.class); + detector = new AnomalyDetector(in); + returnJob = in.readBoolean(); + if (returnJob) { + adJob = new AnomalyDetectorJob(in); + } else { + adJob = null; + } + returnTask = in.readBoolean(); + if (in.readBoolean()) { + realtimeAdTask = new ADTask(in); + } else { + realtimeAdTask = null; + } + if (in.readBoolean()) { + historicalAdTask = new ADTask(in); + } else { + historicalAdTask = null; + } + } + } + + public GetAnomalyDetectorResponse( + long version, + String id, + long primaryTerm, + long seqNo, + AnomalyDetector detector, + AnomalyDetectorJob adJob, + boolean returnJob, + ADTask realtimeAdTask, + ADTask historicalAdTask, + boolean returnTask, + RestStatus restStatus, + DetectorProfile detectorProfile, + EntityProfile entityProfile, + boolean profileResponse + ) { + this.version = version; + this.id = id; + this.primaryTerm = primaryTerm; + this.seqNo = seqNo; + this.detector = detector; + this.restStatus = restStatus; + this.returnJob = returnJob; + + if (this.returnJob) { + this.adJob = adJob; + } else { + this.adJob = null; + } + this.returnTask = returnTask; + if (this.returnTask) { + this.realtimeAdTask = realtimeAdTask; + this.historicalAdTask = historicalAdTask; + } else { + this.realtimeAdTask = null; + this.historicalAdTask = null; + } + this.detectorProfile = detectorProfile; + this.entityProfile = entityProfile; + this.profileResponse = profileResponse; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (profileResponse) { + out.writeBoolean(true); // profileResponse is true + if (detectorProfile != null) { + out.writeString(DETECTOR_PROFILE); + detectorProfile.writeTo(out); + } else if 
(entityProfile != null) { + out.writeString(ENTITY_PROFILE); + entityProfile.writeTo(out); + } + } else { + out.writeBoolean(false); // profileResponse is false + out.writeString(id); + out.writeLong(version); + out.writeLong(primaryTerm); + out.writeLong(seqNo); + out.writeEnum(restStatus); + detector.writeTo(out); + if (returnJob) { + out.writeBoolean(true); // returnJob is true + adJob.writeTo(out); + } else { + out.writeBoolean(false); // returnJob is false + } + out.writeBoolean(returnTask); + if (realtimeAdTask != null) { + out.writeBoolean(true); + realtimeAdTask.writeTo(out); + } else { + out.writeBoolean(false); + } + if (historicalAdTask != null) { + out.writeBoolean(true); + historicalAdTask.writeTo(out); + } else { + out.writeBoolean(false); + } + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (profileResponse) { + if (detectorProfile != null) { + detectorProfile.toXContent(builder, params); + } else { + entityProfile.toXContent(builder, params); + } + } else { + builder.startObject(); + builder.field(RestHandlerUtils._ID, id); + builder.field(RestHandlerUtils._VERSION, version); + builder.field(RestHandlerUtils._PRIMARY_TERM, primaryTerm); + builder.field(RestHandlerUtils._SEQ_NO, seqNo); + builder.field(RestHandlerUtils.ANOMALY_DETECTOR, detector); + if (returnJob) { + builder.field(RestHandlerUtils.ANOMALY_DETECTOR_JOB, adJob); + } + if (returnTask) { + builder.field(RestHandlerUtils.REALTIME_TASK, realtimeAdTask); + builder.field(RestHandlerUtils.HISTORICAL_ANALYSIS_TASK, historicalAdTask); + } + builder.endObject(); + } + return builder; + } + + public DetectorProfile getDetectorProfile() { + return detectorProfile; + } + + public AnomalyDetectorJob getAdJob() { + return adJob; + } + + public ADTask getRealtimeAdTask() { + return realtimeAdTask; + } + + public ADTask getHistoricalAdTask() { + return historicalAdTask; + } + + public AnomalyDetector getDetector() { + return detector; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java index 28ee7278c..473f247dd 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java @@ -14,7 +14,7 @@ import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_GET_DETECTOR; import static org.opensearch.ad.model.ADTaskType.ALL_DETECTOR_TASK_TYPES; import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; import static org.opensearch.timeseries.util.ParseUtils.getUserContext; import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; @@ -61,9 +61,9 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.common.Strings; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.rest.RestStatus; import org.opensearch.tasks.Task; import org.opensearch.timeseries.Name; import 
org.opensearch.timeseries.constant.CommonName; diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java-e new file mode 100644 index 000000000..5bdfabda6 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java-e @@ -0,0 +1,432 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_GET_DETECTOR; +import static org.opensearch.ad.model.ADTaskType.ALL_DETECTOR_TASK_TYPES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; +import static org.opensearch.timeseries.util.ParseUtils.getUserContext; +import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; +import static org.opensearch.timeseries.util.RestHandlerUtils.PROFILE; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.MultiGetItemResponse; +import org.opensearch.action.get.MultiGetRequest; +import org.opensearch.action.get.MultiGetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.AnomalyDetectorProfileRunner; +import org.opensearch.ad.EntityProfileRunner; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.model.DetectorProfile; +import org.opensearch.ad.model.DetectorProfileName; +import org.opensearch.ad.model.EntityProfileName; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.CheckedConsumer; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.Name; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; +import 
org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.Sets; + +public class GetAnomalyDetectorTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(GetAnomalyDetectorTransportAction.class); + + private final ClusterService clusterService; + private final Client client; + private final SecurityClientUtil clientUtil; + private final Set allProfileTypeStrs; + private final Set allProfileTypes; + private final Set defaultDetectorProfileTypes; + private final Set allEntityProfileTypeStrs; + private final Set allEntityProfileTypes; + private final Set defaultEntityProfileTypes; + private final NamedXContentRegistry xContentRegistry; + private final DiscoveryNodeFilterer nodeFilter; + private final TransportService transportService; + private volatile Boolean filterByEnabled; + private final ADTaskManager adTaskManager; + + @Inject + public GetAnomalyDetectorTransportAction( + TransportService transportService, + DiscoveryNodeFilterer nodeFilter, + ActionFilters actionFilters, + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + Settings settings, + NamedXContentRegistry xContentRegistry, + ADTaskManager adTaskManager + ) { + super(GetAnomalyDetectorAction.NAME, transportService, actionFilters, GetAnomalyDetectorRequest::new); + this.clusterService = clusterService; + this.client = client; + this.clientUtil = clientUtil; + List allProfiles = Arrays.asList(DetectorProfileName.values()); + this.allProfileTypes = EnumSet.copyOf(allProfiles); + this.allProfileTypeStrs = getProfileListStrs(allProfiles); + List defaultProfiles = Arrays.asList(DetectorProfileName.ERROR, DetectorProfileName.STATE); + this.defaultDetectorProfileTypes = new HashSet(defaultProfiles); + + List allEntityProfiles = Arrays.asList(EntityProfileName.values()); + this.allEntityProfileTypes = EnumSet.copyOf(allEntityProfiles); + this.allEntityProfileTypeStrs = getProfileListStrs(allEntityProfiles); + List defaultEntityProfiles = Arrays.asList(EntityProfileName.STATE); + this.defaultEntityProfileTypes = new HashSet(defaultEntityProfiles); + + this.xContentRegistry = xContentRegistry; + this.nodeFilter = nodeFilter; + filterByEnabled = AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + this.transportService = transportService; + this.adTaskManager = adTaskManager; + } + + @Override + protected void doExecute(Task task, GetAnomalyDetectorRequest request, ActionListener actionListener) { + String detectorID = request.getDetectorID(); + User user = getUserContext(client); + ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_GET_DETECTOR); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute( + user, + detectorID, + filterByEnabled, + listener, + (anomalyDetector) -> getExecute(request, listener), + client, + clusterService, + xContentRegistry + ); + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + + protected void getExecute(GetAnomalyDetectorRequest request, ActionListener listener) { + String detectorID = request.getDetectorID(); + String typesStr = request.getTypeStr(); + String rawPath = request.getRawPath(); + Entity entity = request.getEntity(); + boolean 
all = request.isAll(); + boolean returnJob = request.isReturnJob(); + boolean returnTask = request.isReturnTask(); + + try { + if (!Strings.isEmpty(typesStr) || rawPath.endsWith(PROFILE) || rawPath.endsWith(PROFILE + "/")) { + if (entity != null) { + Set entityProfilesToCollect = getEntityProfilesToCollect(typesStr, all); + EntityProfileRunner profileRunner = new EntityProfileRunner( + client, + clientUtil, + xContentRegistry, + AnomalyDetectorSettings.NUM_MIN_SAMPLES + ); + profileRunner + .profile( + detectorID, + entity, + entityProfilesToCollect, + ActionListener + .wrap( + profile -> { + listener + .onResponse( + new GetAnomalyDetectorResponse( + 0, + null, + 0, + 0, + null, + null, + false, + null, + null, + false, + null, + null, + profile, + true + ) + ); + }, + e -> listener.onFailure(e) + ) + ); + } else { + Set profilesToCollect = getProfilesToCollect(typesStr, all); + AnomalyDetectorProfileRunner profileRunner = new AnomalyDetectorProfileRunner( + client, + clientUtil, + xContentRegistry, + nodeFilter, + AnomalyDetectorSettings.NUM_MIN_SAMPLES, + transportService, + adTaskManager + ); + profileRunner.profile(detectorID, getProfileActionListener(listener), profilesToCollect); + } + } else { + if (returnTask) { + adTaskManager.getAndExecuteOnLatestADTasks(detectorID, null, null, ALL_DETECTOR_TASK_TYPES, (taskList) -> { + Optional realtimeAdTask = Optional.empty(); + Optional historicalAdTask = Optional.empty(); + + if (taskList != null && taskList.size() > 0) { + Map adTasks = new HashMap<>(); + List duplicateAdTasks = new ArrayList<>(); + for (ADTask task : taskList) { + if (adTasks.containsKey(task.getTaskType())) { + LOG + .info( + "Found duplicate latest task of detector {}, task id: {}, task type: {}", + detectorID, + task.getTaskType(), + task.getTaskId() + ); + duplicateAdTasks.add(task); + continue; + } + adTasks.put(task.getTaskType(), task); + } + if (duplicateAdTasks.size() > 0) { + adTaskManager.resetLatestFlagAsFalse(duplicateAdTasks); + } + + if (adTasks.containsKey(ADTaskType.REALTIME_HC_DETECTOR.name())) { + realtimeAdTask = Optional.ofNullable(adTasks.get(ADTaskType.REALTIME_HC_DETECTOR.name())); + } else if (adTasks.containsKey(ADTaskType.REALTIME_SINGLE_ENTITY.name())) { + realtimeAdTask = Optional.ofNullable(adTasks.get(ADTaskType.REALTIME_SINGLE_ENTITY.name())); + } + if (adTasks.containsKey(ADTaskType.HISTORICAL_HC_DETECTOR.name())) { + historicalAdTask = Optional.ofNullable(adTasks.get(ADTaskType.HISTORICAL_HC_DETECTOR.name())); + } else if (adTasks.containsKey(ADTaskType.HISTORICAL_SINGLE_ENTITY.name())) { + historicalAdTask = Optional.ofNullable(adTasks.get(ADTaskType.HISTORICAL_SINGLE_ENTITY.name())); + } else if (adTasks.containsKey(ADTaskType.HISTORICAL.name())) { + historicalAdTask = Optional.ofNullable(adTasks.get(ADTaskType.HISTORICAL.name())); + } + } + getDetectorAndJob(detectorID, returnJob, returnTask, realtimeAdTask, historicalAdTask, listener); + }, transportService, true, 2, listener); + } else { + getDetectorAndJob(detectorID, returnJob, returnTask, Optional.empty(), Optional.empty(), listener); + } + } + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + + private void getDetectorAndJob( + String detectorID, + boolean returnJob, + boolean returnTask, + Optional realtimeAdTask, + Optional historicalAdTask, + ActionListener listener + ) { + MultiGetRequest.Item adItem = new MultiGetRequest.Item(CommonName.CONFIG_INDEX, detectorID); + MultiGetRequest multiGetRequest = new MultiGetRequest().add(adItem); + if (returnJob) 
{ + MultiGetRequest.Item adJobItem = new MultiGetRequest.Item(CommonName.JOB_INDEX, detectorID); + multiGetRequest.add(adJobItem); + } + client.multiGet(multiGetRequest, onMultiGetResponse(listener, returnJob, returnTask, realtimeAdTask, historicalAdTask, detectorID)); + } + + private ActionListener onMultiGetResponse( + ActionListener listener, + boolean returnJob, + boolean returnTask, + Optional realtimeAdTask, + Optional historicalAdTask, + String detectorId + ) { + return new ActionListener() { + @Override + public void onResponse(MultiGetResponse multiGetResponse) { + MultiGetItemResponse[] responses = multiGetResponse.getResponses(); + AnomalyDetector detector = null; + AnomalyDetectorJob adJob = null; + String id = null; + long version = 0; + long seqNo = 0; + long primaryTerm = 0; + + for (MultiGetItemResponse response : responses) { + if (CommonName.CONFIG_INDEX.equals(response.getIndex())) { + if (response.getResponse() == null || !response.getResponse().isExists()) { + listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, RestStatus.NOT_FOUND)); + return; + } + id = response.getId(); + version = response.getResponse().getVersion(); + primaryTerm = response.getResponse().getPrimaryTerm(); + seqNo = response.getResponse().getSeqNo(); + if (!response.getResponse().isSourceEmpty()) { + try ( + XContentParser parser = RestHandlerUtils + .createXContentParserFromRegistry(xContentRegistry, response.getResponse().getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + detector = parser.namedObject(AnomalyDetector.class, AnomalyDetector.PARSE_FIELD_NAME, null); + } catch (Exception e) { + String message = "Failed to parse detector job " + detectorId; + listener.onFailure(buildInternalServerErrorResponse(e, message)); + return; + } + } + } + + if (CommonName.JOB_INDEX.equals(response.getIndex())) { + if (response.getResponse() != null + && response.getResponse().isExists() + && !response.getResponse().isSourceEmpty()) { + try ( + XContentParser parser = RestHandlerUtils + .createXContentParserFromRegistry(xContentRegistry, response.getResponse().getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + adJob = AnomalyDetectorJob.parse(parser); + } catch (Exception e) { + String message = "Failed to parse detector job " + detectorId; + listener.onFailure(buildInternalServerErrorResponse(e, message)); + return; + } + } + } + } + listener + .onResponse( + new GetAnomalyDetectorResponse( + version, + id, + primaryTerm, + seqNo, + detector, + adJob, + returnJob, + realtimeAdTask.orElse(null), + historicalAdTask.orElse(null), + returnTask, + RestStatus.OK, + null, + null, + false + ) + ); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }; + } + + private ActionListener getProfileActionListener(ActionListener listener) { + return ActionListener.wrap(new CheckedConsumer() { + @Override + public void accept(DetectorProfile profile) throws Exception { + listener + .onResponse( + new GetAnomalyDetectorResponse(0, null, 0, 0, null, null, false, null, null, false, null, profile, null, true) + ); + } + }, exception -> { listener.onFailure(exception); }); + } + + private OpenSearchStatusException buildInternalServerErrorResponse(Exception e, String errorMsg) { + LOG.error(errorMsg, e); + return new OpenSearchStatusException(errorMsg, RestStatus.INTERNAL_SERVER_ERROR); + } + + /** + * + * @param typesStr a list of 
input profile types separated by comma + * @param all whether we should return all profile in the response + * @return profiles to collect for a detector + */ + private Set getProfilesToCollect(String typesStr, boolean all) { + if (all) { + return this.allProfileTypes; + } else if (Strings.isEmpty(typesStr)) { + return this.defaultDetectorProfileTypes; + } else { + // Filter out unsupported types + Set typesInRequest = new HashSet<>(Arrays.asList(typesStr.split(","))); + return DetectorProfileName.getNames(Sets.intersection(allProfileTypeStrs, typesInRequest)); + } + } + + /** + * + * @param typesStr a list of input profile types separated by comma + * @param all whether we should return all profile in the response + * @return profiles to collect for an entity + */ + private Set getEntityProfilesToCollect(String typesStr, boolean all) { + if (all) { + return this.allEntityProfileTypes; + } else if (Strings.isEmpty(typesStr)) { + return this.defaultEntityProfileTypes; + } else { + // Filter out unsupported types + Set typesInRequest = new HashSet<>(Arrays.asList(typesStr.split(","))); + return EntityProfileName.getNames(Sets.intersection(allEntityProfileTypeStrs, typesInRequest)); + } + } + + private Set getProfileListStrs(List profileList) { + return profileList.stream().map(profile -> profile.getName()).collect(Collectors.toSet()); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorAction.java-e b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorAction.java-e new file mode 100644 index 000000000..9ee038336 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorAction.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class IndexAnomalyDetectorAction extends ActionType { + // External Action which used for public facing RestAPIs. 
+ public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/write"; + public static final IndexAnomalyDetectorAction INSTANCE = new IndexAnomalyDetectorAction(); + + private IndexAnomalyDetectorAction() { + super(NAME, IndexAnomalyDetectorResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java index 83da583e7..572e847f9 100644 --- a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java @@ -17,9 +17,9 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.support.WriteRequest; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.rest.RestRequest; public class IndexAnomalyDetectorRequest extends ActionRequest { diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java-e b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java-e new file mode 100644 index 000000000..ef44e051f --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java-e @@ -0,0 +1,136 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.rest.RestRequest; + +public class IndexAnomalyDetectorRequest extends ActionRequest { + + private String detectorID; + private long seqNo; + private long primaryTerm; + private WriteRequest.RefreshPolicy refreshPolicy; + private AnomalyDetector detector; + private RestRequest.Method method; + private TimeValue requestTimeout; + private Integer maxSingleEntityAnomalyDetectors; + private Integer maxMultiEntityAnomalyDetectors; + private Integer maxAnomalyFeatures; + + public IndexAnomalyDetectorRequest(StreamInput in) throws IOException { + super(in); + detectorID = in.readString(); + seqNo = in.readLong(); + primaryTerm = in.readLong(); + refreshPolicy = in.readEnum(WriteRequest.RefreshPolicy.class); + detector = new AnomalyDetector(in); + method = in.readEnum(RestRequest.Method.class); + requestTimeout = in.readTimeValue(); + maxSingleEntityAnomalyDetectors = in.readInt(); + maxMultiEntityAnomalyDetectors = in.readInt(); + maxAnomalyFeatures = in.readInt(); + } + + public IndexAnomalyDetectorRequest( + String detectorID, + long seqNo, + long primaryTerm, + WriteRequest.RefreshPolicy refreshPolicy, + AnomalyDetector detector, + RestRequest.Method method, + TimeValue requestTimeout, + Integer maxSingleEntityAnomalyDetectors, + Integer maxMultiEntityAnomalyDetectors, + Integer maxAnomalyFeatures + ) { + super(); + this.detectorID = detectorID; + this.seqNo = seqNo; + this.primaryTerm = primaryTerm; + this.refreshPolicy = refreshPolicy; + this.detector = detector; + this.method = method; + this.requestTimeout = requestTimeout; + this.maxSingleEntityAnomalyDetectors = maxSingleEntityAnomalyDetectors; + this.maxMultiEntityAnomalyDetectors = maxMultiEntityAnomalyDetectors; + this.maxAnomalyFeatures = maxAnomalyFeatures; + } + + public String getDetectorID() { + return detectorID; + } + + public long getSeqNo() { + return seqNo; + } + + public long getPrimaryTerm() { + return primaryTerm; + } + + public WriteRequest.RefreshPolicy getRefreshPolicy() { + return refreshPolicy; + } + + public AnomalyDetector getDetector() { + return detector; + } + + public RestRequest.Method getMethod() { + return method; + } + + public TimeValue getRequestTimeout() { + return requestTimeout; + } + + public Integer getMaxSingleEntityAnomalyDetectors() { + return maxSingleEntityAnomalyDetectors; + } + + public Integer getMaxMultiEntityAnomalyDetectors() { + return maxMultiEntityAnomalyDetectors; + } + + public Integer getMaxAnomalyFeatures() { + return maxAnomalyFeatures; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(detectorID); + out.writeLong(seqNo); + out.writeLong(primaryTerm); + out.writeEnum(refreshPolicy); + detector.writeTo(out); + out.writeEnum(method); + out.writeTimeValue(requestTimeout); + out.writeInt(maxSingleEntityAnomalyDetectors); + out.writeInt(maxMultiEntityAnomalyDetectors); + out.writeInt(maxAnomalyFeatures); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git 
a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorResponse.java b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorResponse.java index 39b014ff8..a6ae3845c 100644 --- a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorResponse.java @@ -15,11 +15,11 @@ import org.opensearch.action.ActionResponse; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.rest.RestStatus; import org.opensearch.timeseries.util.RestHandlerUtils; public class IndexAnomalyDetectorResponse extends ActionResponse implements ToXContentObject { diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorResponse.java-e b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorResponse.java-e new file mode 100644 index 000000000..cc545f377 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorResponse.java-e @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionResponse; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.timeseries.util.RestHandlerUtils; + +public class IndexAnomalyDetectorResponse extends ActionResponse implements ToXContentObject { + private final String id; + private final long version; + private final long seqNo; + private final long primaryTerm; + private final AnomalyDetector detector; + private final RestStatus restStatus; + + public IndexAnomalyDetectorResponse(StreamInput in) throws IOException { + super(in); + id = in.readString(); + version = in.readLong(); + seqNo = in.readLong(); + primaryTerm = in.readLong(); + detector = new AnomalyDetector(in); + restStatus = in.readEnum(RestStatus.class); + } + + public IndexAnomalyDetectorResponse( + String id, + long version, + long seqNo, + long primaryTerm, + AnomalyDetector detector, + RestStatus restStatus + ) { + this.id = id; + this.version = version; + this.seqNo = seqNo; + this.primaryTerm = primaryTerm; + this.detector = detector; + this.restStatus = restStatus; + } + + public String getId() { + return id; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + out.writeLong(version); + out.writeLong(seqNo); + out.writeLong(primaryTerm); + detector.writeTo(out); + out.writeEnum(restStatus); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder + .startObject() + .field(RestHandlerUtils._ID, id) + 
.field(RestHandlerUtils._VERSION, version) + .field(RestHandlerUtils._SEQ_NO, seqNo) + .field(RestHandlerUtils.ANOMALY_DETECTOR, detector) + .field(RestHandlerUtils._PRIMARY_TERM, primaryTerm) + .endObject(); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java-e new file mode 100644 index 000000000..06018ae6c --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java-e @@ -0,0 +1,205 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_CREATE_DETECTOR; +import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_UPDATE_DETECTOR; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.timeseries.util.ParseUtils.checkFilterByBackendRoles; +import static org.opensearch.timeseries.util.ParseUtils.getDetector; +import static org.opensearch.timeseries.util.ParseUtils.getUserContext; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import java.util.List; +import java.util.function.Consumer; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.rest.handler.IndexAnomalyDetectorActionHandler; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.transport.TransportService; + +public class IndexAnomalyDetectorTransportAction extends HandledTransportAction { + private static final Logger LOG = LogManager.getLogger(IndexAnomalyDetectorTransportAction.class); + private final Client client; + private final SecurityClientUtil clientUtil; + private final TransportService transportService; + private final ADIndexManagement anomalyDetectionIndices; + private final ClusterService clusterService; + private final NamedXContentRegistry xContentRegistry; + private final ADTaskManager adTaskManager; + private volatile Boolean filterByEnabled; 
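The filterByEnabled field above is volatile because its value is refreshed by a cluster-settings consumer registered in the constructor that follows, so changes to FILTER_BY_BACKEND_ROLES take effect without a node restart. Below is a minimal standalone sketch of that pattern; DynamicFlag is a hypothetical helper, and it assumes the Setting passed in has been registered as a dynamic cluster setting.

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;

public final class DynamicFlag {
    // volatile so reads on request threads see updates applied by the settings consumer
    private volatile boolean enabled;

    public DynamicFlag(Setting<Boolean> setting, Settings nodeSettings, ClusterService clusterService) {
        // seed from the node's startup settings
        this.enabled = setting.get(nodeSettings);
        // re-read whenever the cluster setting changes; requires the setting to be dynamic
        clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, value -> this.enabled = value);
    }

    public boolean isEnabled() {
        return enabled;
    }
}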
+ private final SearchFeatureDao searchFeatureDao; + private final Settings settings; + + @Inject + public IndexAnomalyDetectorTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + SecurityClientUtil clientUtil, + ClusterService clusterService, + Settings settings, + ADIndexManagement anomalyDetectionIndices, + NamedXContentRegistry xContentRegistry, + ADTaskManager adTaskManager, + SearchFeatureDao searchFeatureDao + ) { + super(IndexAnomalyDetectorAction.NAME, transportService, actionFilters, IndexAnomalyDetectorRequest::new); + this.client = client; + this.clientUtil = clientUtil; + this.transportService = transportService; + this.clusterService = clusterService; + this.anomalyDetectionIndices = anomalyDetectionIndices; + this.xContentRegistry = xContentRegistry; + this.adTaskManager = adTaskManager; + this.searchFeatureDao = searchFeatureDao; + filterByEnabled = AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + this.settings = settings; + } + + @Override + protected void doExecute(Task task, IndexAnomalyDetectorRequest request, ActionListener actionListener) { + User user = getUserContext(client); + String detectorId = request.getDetectorID(); + RestRequest.Method method = request.getMethod(); + String errorMessage = method == RestRequest.Method.PUT ? FAIL_TO_UPDATE_DETECTOR : FAIL_TO_CREATE_DETECTOR; + ActionListener listener = wrapRestActionListener(actionListener, errorMessage); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute(user, detectorId, method, listener, (detector) -> adExecute(request, user, detector, context, listener)); + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + + private void resolveUserAndExecute( + User requestedUser, + String detectorId, + RestRequest.Method method, + ActionListener listener, + Consumer function + ) { + try { + // Check if user has backend roles + // When filter by is enabled, block users creating/updating detectors who do not have backend roles. + if (filterByEnabled && !checkFilterByBackendRoles(requestedUser, listener)) { + return; + } + if (method == RestRequest.Method.PUT) { + // requestedUser == null means security is disabled or user is superadmin. In this case we don't need to + // check if request user have access to the detector or not. But we still need to get current detector for + // this case, so we can keep current detector's user data. + boolean filterByBackendRole = requestedUser == null ? false : filterByEnabled; + // Update detector request, check if user has permissions to update the detector + // Get detector and verify backend roles + getDetector(requestedUser, detectorId, listener, function, client, clusterService, xContentRegistry, filterByBackendRole); + } else { + // Create Detector. No need to get current detector. 
+ function.accept(null); + } + } catch (Exception e) { + listener.onFailure(e); + } + } + + protected void adExecute( + IndexAnomalyDetectorRequest request, + User user, + AnomalyDetector currentDetector, + ThreadContext.StoredContext storedContext, + ActionListener listener + ) { + anomalyDetectionIndices.update(); + String detectorId = request.getDetectorID(); + long seqNo = request.getSeqNo(); + long primaryTerm = request.getPrimaryTerm(); + WriteRequest.RefreshPolicy refreshPolicy = request.getRefreshPolicy(); + AnomalyDetector detector = request.getDetector(); + RestRequest.Method method = request.getMethod(); + TimeValue requestTimeout = request.getRequestTimeout(); + Integer maxSingleEntityAnomalyDetectors = request.getMaxSingleEntityAnomalyDetectors(); + Integer maxMultiEntityAnomalyDetectors = request.getMaxMultiEntityAnomalyDetectors(); + Integer maxAnomalyFeatures = request.getMaxAnomalyFeatures(); + + storedContext.restore(); + checkIndicesAndExecute(detector.getIndices(), () -> { + // Don't replace detector's user when update detector + // Github issue: https://github.com/opensearch-project/anomaly-detection/issues/124 + User detectorUser = currentDetector == null ? user : currentDetector.getUser(); + IndexAnomalyDetectorActionHandler indexAnomalyDetectorActionHandler = new IndexAnomalyDetectorActionHandler( + clusterService, + client, + clientUtil, + transportService, + listener, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + detector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + method, + xContentRegistry, + detectorUser, + adTaskManager, + searchFeatureDao, + settings + ); + indexAnomalyDetectorActionHandler.start(); + }, listener); + } + + private void checkIndicesAndExecute( + List indices, + ExecutorFunction function, + ActionListener listener + ) { + SearchRequest searchRequest = new SearchRequest() + .indices(indices.toArray(new String[0])) + .source(new SearchSourceBuilder().size(1).query(QueryBuilders.matchAllQuery())); + client.search(searchRequest, ActionListener.wrap(r -> { function.execute(); }, e -> { + // Due to below issue with security plugin, we get security_exception when invalid index name is mentioned. + // https://github.com/opendistro-for-elasticsearch/security/issues/718 + LOG.error(e); + listener.onFailure(e); + })); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorAction.java-e b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorAction.java-e new file mode 100644 index 000000000..c90ecc446 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorAction.java-e @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class PreviewAnomalyDetectorAction extends ActionType { + // External Action which used for public facing RestAPIs. 
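As with the index action, the preview action is driven entirely by its request object. The caller-side sketch below uses only the PreviewAnomalyDetectorRequest constructor and the PreviewAnomalyDetectorAction instance defined in this patch; passing a null detector together with a detector id assumes the transport action loads the saved configuration by id (the path handled in previewAnomalyDetector), and the one-day window, class name, and logging are illustrative choices.

import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;

import org.opensearch.action.ActionListener;
import org.opensearch.ad.transport.PreviewAnomalyDetectorAction;
import org.opensearch.ad.transport.PreviewAnomalyDetectorRequest;
import org.opensearch.client.Client;

public final class PreviewDetectorExample {
    // Previews anomaly results over the last 24 hours for a detector that is already saved.
    public static void previewLastDay(Client client, String detectorId) throws IOException {
        Instant end = Instant.now();
        Instant start = end.minus(1, ChronoUnit.DAYS);
        // detector == null: the transport action fetches the stored configuration by id
        PreviewAnomalyDetectorRequest request = new PreviewAnomalyDetectorRequest(null, detectorId, start, end);
        client.execute(PreviewAnomalyDetectorAction.INSTANCE, request, ActionListener.wrap(
            response -> System.out.println("preview finished for detector " + detectorId),
            exception -> exception.printStackTrace()
        ));
    }
}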
+ public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/preview"; + public static final PreviewAnomalyDetectorAction INSTANCE = new PreviewAnomalyDetectorAction(); + + private PreviewAnomalyDetectorAction() { + super(NAME, PreviewAnomalyDetectorResponse::new); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorRequest.java b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorRequest.java index d7a097bc8..11fa848f7 100644 --- a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorRequest.java @@ -17,8 +17,8 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; public class PreviewAnomalyDetectorRequest extends ActionRequest { diff --git a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorRequest.java-e b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorRequest.java-e new file mode 100644 index 000000000..11fa848f7 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorRequest.java-e @@ -0,0 +1,76 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class PreviewAnomalyDetectorRequest extends ActionRequest { + + private AnomalyDetector detector; + private String detectorId; + private Instant startTime; + private Instant endTime; + + public PreviewAnomalyDetectorRequest(StreamInput in) throws IOException { + super(in); + detector = new AnomalyDetector(in); + detectorId = in.readOptionalString(); + startTime = in.readInstant(); + endTime = in.readInstant(); + } + + public PreviewAnomalyDetectorRequest(AnomalyDetector detector, String detectorId, Instant startTime, Instant endTime) + throws IOException { + super(); + this.detector = detector; + this.detectorId = detectorId; + this.startTime = startTime; + this.endTime = endTime; + } + + public AnomalyDetector getDetector() { + return detector; + } + + public String getId() { + return detectorId; + } + + public Instant getStartTime() { + return startTime; + } + + public Instant getEndTime() { + return endTime; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + detector.writeTo(out); + out.writeOptionalString(detectorId); + out.writeInstant(startTime); + out.writeInstant(endTime); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorResponse.java 
b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorResponse.java index 664e87234..2bba63d9a 100644 --- a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorResponse.java @@ -17,8 +17,8 @@ import org.opensearch.action.ActionResponse; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorResponse.java-e b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorResponse.java-e new file mode 100644 index 000000000..2bba63d9a --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorResponse.java-e @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.action.ActionResponse; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +public class PreviewAnomalyDetectorResponse extends ActionResponse implements ToXContentObject { + public static final String ANOMALY_RESULT = "anomaly_result"; + public static final String ANOMALY_DETECTOR = "anomaly_detector"; + private List anomalyResult; + private AnomalyDetector detector; + + public PreviewAnomalyDetectorResponse(StreamInput in) throws IOException { + super(in); + anomalyResult = in.readList(AnomalyResult::new); + detector = new AnomalyDetector(in); + } + + public PreviewAnomalyDetectorResponse(List anomalyResult, AnomalyDetector detector) { + this.anomalyResult = anomalyResult; + this.detector = detector; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeList(anomalyResult); + detector.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject().field(ANOMALY_RESULT, anomalyResult).field(ANOMALY_DETECTOR, detector).endObject(); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java index 4111fcb5b..5d6bdd193 100644 --- a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java @@ -15,7 +15,7 @@ import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ANOMALY_FEATURES; import static 
org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_CONCURRENT_PREVIEW; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.util.ParseUtils.getUserContext; import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; @@ -47,9 +47,9 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.rest.RestStatus; import org.opensearch.tasks.Task; import org.opensearch.timeseries.common.exception.ClientException; import org.opensearch.timeseries.common.exception.LimitExceededException; diff --git a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java-e new file mode 100644 index 000000000..f22096f43 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java-e @@ -0,0 +1,252 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_PREVIEW_DETECTOR; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ANOMALY_FEATURES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_CONCURRENT_PREVIEW; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.ParseUtils.getUserContext; +import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import java.io.IOException; +import java.time.Instant; +import java.util.List; +import java.util.concurrent.Semaphore; + +import org.apache.commons.lang.StringUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.AnomalyDetectorRunner; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.CheckedConsumer; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import 
org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.common.exception.ClientException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +public class PreviewAnomalyDetectorTransportAction extends + HandledTransportAction { + private final Logger logger = LogManager.getLogger(PreviewAnomalyDetectorTransportAction.class); + private final AnomalyDetectorRunner anomalyDetectorRunner; + private final ClusterService clusterService; + private final Client client; + private final NamedXContentRegistry xContentRegistry; + private volatile Integer maxAnomalyFeatures; + private volatile Boolean filterByEnabled; + private final ADCircuitBreakerService adCircuitBreakerService; + private Semaphore lock; + + @Inject + public PreviewAnomalyDetectorTransportAction( + Settings settings, + TransportService transportService, + ClusterService clusterService, + ActionFilters actionFilters, + Client client, + AnomalyDetectorRunner anomalyDetectorRunner, + NamedXContentRegistry xContentRegistry, + ADCircuitBreakerService adCircuitBreakerService + ) { + super(PreviewAnomalyDetectorAction.NAME, transportService, actionFilters, PreviewAnomalyDetectorRequest::new); + this.clusterService = clusterService; + this.client = client; + this.anomalyDetectorRunner = anomalyDetectorRunner; + this.xContentRegistry = xContentRegistry; + maxAnomalyFeatures = MAX_ANOMALY_FEATURES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_ANOMALY_FEATURES, it -> maxAnomalyFeatures = it); + filterByEnabled = AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + this.adCircuitBreakerService = adCircuitBreakerService; + this.lock = new Semaphore(MAX_CONCURRENT_PREVIEW.get(settings), true); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_CONCURRENT_PREVIEW, it -> { lock = new Semaphore(it); }); + } + + @Override + protected void doExecute( + Task task, + PreviewAnomalyDetectorRequest request, + ActionListener actionListener + ) { + String detectorId = request.getId(); + User user = getUserContext(client); + ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_PREVIEW_DETECTOR); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute( + user, + detectorId, + filterByEnabled, + listener, + (anomalyDetector) -> previewExecute(request, context, listener), + client, + clusterService, + xContentRegistry + ); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + void previewExecute( + PreviewAnomalyDetectorRequest request, + ThreadContext.StoredContext context, + ActionListener listener + ) { + if (adCircuitBreakerService.isOpen()) { + listener.onFailure(new LimitExceededException(request.getId(), CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); + return; + } + try { 
+ if (!lock.tryAcquire()) { + listener.onFailure(new ClientException(request.getId(), ADCommonMessages.REQUEST_THROTTLED_MSG)); + return; + } + + try { + AnomalyDetector detector = request.getDetector(); + String detectorId = request.getId(); + Instant startTime = request.getStartTime(); + Instant endTime = request.getEndTime(); + ActionListener releaseListener = ActionListener.runAfter(listener, () -> lock.release()); + if (detector != null) { + String error = validateDetector(detector); + if (StringUtils.isNotBlank(error)) { + listener.onFailure(new OpenSearchStatusException(error, RestStatus.BAD_REQUEST)); + lock.release(); + return; + } + anomalyDetectorRunner + .executeDetector( + detector, + startTime, + endTime, + context, + getPreviewDetectorActionListener(releaseListener, detector) + ); + } else { + previewAnomalyDetector(releaseListener, detectorId, detector, startTime, endTime, context); + } + } catch (Exception e) { + logger.error("Fail to preview", e); + lock.release(); + } + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + private String validateDetector(AnomalyDetector detector) { + if (detector.getFeatureAttributes().isEmpty()) { + return "Can't preview detector without feature"; + } else { + return RestHandlerUtils.checkFeaturesSyntax(detector, maxAnomalyFeatures); + } + } + + private ActionListener> getPreviewDetectorActionListener( + ActionListener listener, + AnomalyDetector detector + ) { + return ActionListener.wrap(new CheckedConsumer, Exception>() { + @Override + public void accept(List anomalyResult) throws Exception { + PreviewAnomalyDetectorResponse response = new PreviewAnomalyDetectorResponse(anomalyResult, detector); + listener.onResponse(response); + } + }, exception -> { + logger.error("Unexpected error running anomaly detector " + detector.getId(), exception); + listener + .onFailure( + new OpenSearchStatusException( + "Unexpected error running anomaly detector " + detector.getId() + ". 
" + exception.getMessage(), + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + }); + } + + private void previewAnomalyDetector( + ActionListener listener, + String detectorId, + AnomalyDetector detector, + Instant startTime, + Instant endTime, + ThreadContext.StoredContext context + ) throws IOException { + if (!StringUtils.isBlank(detectorId)) { + GetRequest getRequest = new GetRequest(CommonName.CONFIG_INDEX).id(detectorId); + client.get(getRequest, onGetAnomalyDetectorResponse(listener, startTime, endTime, context)); + } else { + anomalyDetectorRunner + .executeDetector(detector, startTime, endTime, context, getPreviewDetectorActionListener(listener, detector)); + } + } + + private ActionListener onGetAnomalyDetectorResponse( + ActionListener listener, + Instant startTime, + Instant endTime, + ThreadContext.StoredContext context + ) { + return ActionListener.wrap(new CheckedConsumer() { + @Override + public void accept(GetResponse response) throws Exception { + if (!response.isExists()) { + listener + .onFailure( + new OpenSearchStatusException("Can't find anomaly detector with id:" + response.getId(), RestStatus.NOT_FOUND) + ); + return; + } + + try { + XContentParser parser = RestHandlerUtils + .createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId(), response.getVersion()); + + anomalyDetectorRunner + .executeDetector(detector, startTime, endTime, context, getPreviewDetectorActionListener(listener, detector)); + } catch (IOException e) { + listener.onFailure(e); + } + } + }, exception -> { listener.onFailure(new TimeSeriesException("Could not execute get query to find detector")); }); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ProfileAction.java-e b/src/main/java/org/opensearch/ad/transport/ProfileAction.java-e new file mode 100644 index 000000000..291dd0982 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ProfileAction.java-e @@ -0,0 +1,32 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +/** + * Profile transport action + */ +public class ProfileAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. 
+ public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile"; + public static final ProfileAction INSTANCE = new ProfileAction(); + + /** + * Constructor + */ + private ProfileAction() { + super(NAME, ProfileResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ProfileNodeRequest.java b/src/main/java/org/opensearch/ad/transport/ProfileNodeRequest.java index 6ba40fe1a..d3db87d33 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileNodeRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ProfileNodeRequest.java @@ -15,8 +15,8 @@ import java.util.Set; import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.transport.TransportRequest; /** diff --git a/src/main/java/org/opensearch/ad/transport/ProfileNodeRequest.java-e b/src/main/java/org/opensearch/ad/transport/ProfileNodeRequest.java-e new file mode 100644 index 000000000..d3db87d33 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ProfileNodeRequest.java-e @@ -0,0 +1,68 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.Set; + +import org.opensearch.ad.model.DetectorProfileName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +/** + * Class representing a nodes's profile request + */ +public class ProfileNodeRequest extends TransportRequest { + private ProfileRequest request; + + public ProfileNodeRequest(StreamInput in) throws IOException { + super(in); + this.request = new ProfileRequest(in); + } + + /** + * Constructor + * + * @param request profile request + */ + public ProfileNodeRequest(ProfileRequest request) { + this.request = request; + } + + public String getId() { + return request.getId(); + } + + /** + * Get the set that tracks which profiles should be retrieved + * + * @return the set that contains the profile names marked for retrieval + */ + public Set getProfilesToBeRetrieved() { + return request.getProfilesToBeRetrieved(); + } + + /** + * + * @return Whether this is about a multi-entity detector or not + */ + public boolean isForMultiEntityDetector() { + return request.isForMultiEntityDetector(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + request.writeTo(out); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java b/src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java index a0dcae3df..9517f6add 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java @@ -19,8 +19,8 @@ import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.model.ModelProfile; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import 
org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.constant.CommonName; diff --git a/src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java-e b/src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java-e new file mode 100644 index 000000000..9517f6add --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java-e @@ -0,0 +1,179 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.opensearch.action.support.nodes.BaseNodeResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.ModelProfile; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; + +/** + * Profile response on a node + */ +public class ProfileNodeResponse extends BaseNodeResponse implements ToXContentFragment { + private Map modelSize; + private int shingleSize; + private long activeEntities; + private long totalUpdates; + // added after OpenSearch 1.0 + private List modelProfiles; + private long modelCount; + + /** + * Constructor + * + * @param in StreamInput + * @throws IOException throws an IO exception if the StreamInput cannot be read from + */ + public ProfileNodeResponse(StreamInput in) throws IOException { + super(in); + if (in.readBoolean()) { + modelSize = in.readMap(StreamInput::readString, StreamInput::readLong); + } + shingleSize = in.readInt(); + activeEntities = in.readVLong(); + totalUpdates = in.readVLong(); + if (in.readBoolean()) { + // added after OpenSearch 1.0 + modelProfiles = in.readList(ModelProfile::new); + modelCount = in.readVLong(); + } + } + + /** + * Constructor + * + * @param node DiscoveryNode object + * @param modelSize Mapping of model id to its memory consumption in bytes + * @param shingleSize shingle size + * @param activeEntity active entity count + * @param totalUpdates RCF model total updates + * @param modelProfiles a collection of model profiles like model size + * @param modelCount the number of models on the node + */ + public ProfileNodeResponse( + DiscoveryNode node, + Map modelSize, + int shingleSize, + long activeEntity, + long totalUpdates, + List modelProfiles, + long modelCount + ) { + super(node); + this.modelSize = modelSize; + this.shingleSize = shingleSize; + this.activeEntities = activeEntity; + this.totalUpdates = totalUpdates; + this.modelProfiles = modelProfiles; + this.modelCount = modelCount; + } + + /** + * Creates a new ProfileNodeResponse object and reads in the profile from an input stream + * + * @param in StreamInput to read from + * @return ProfileNodeResponse object corresponding to the input stream + * @throws IOException throws an IO exception if the StreamInput cannot be read from + */ + public static ProfileNodeResponse 
readProfiles(StreamInput in) throws IOException { + return new ProfileNodeResponse(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + if (modelSize != null) { + out.writeBoolean(true); + out.writeMap(modelSize, StreamOutput::writeString, StreamOutput::writeLong); + } else { + out.writeBoolean(false); + } + + out.writeInt(shingleSize); + out.writeVLong(activeEntities); + out.writeVLong(totalUpdates); + // added after OpenSearch 1.0 + if (modelProfiles != null) { + out.writeBoolean(true); + out.writeList(modelProfiles); + out.writeVLong(modelCount); + } else { + out.writeBoolean(false); + } + } + + /** + * Converts profile to xContent + * + * @param builder XContentBuilder + * @param params Params + * @return XContentBuilder + * @throws IOException thrown by builder for invalid field + */ + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(CommonName.MODEL_SIZE_IN_BYTES); + for (Map.Entry entry : modelSize.entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + + builder.field(ADCommonName.SHINGLE_SIZE, shingleSize); + builder.field(ADCommonName.ACTIVE_ENTITIES, activeEntities); + builder.field(ADCommonName.TOTAL_UPDATES, totalUpdates); + + builder.field(ADCommonName.MODEL_COUNT, modelCount); + builder.startArray(ADCommonName.MODELS); + for (ModelProfile modelProfile : modelProfiles) { + builder.startObject(); + modelProfile.toXContent(builder, params); + builder.endObject(); + } + builder.endArray(); + + return builder; + } + + public Map getModelSize() { + return modelSize; + } + + public int getShingleSize() { + return shingleSize; + } + + public long getActiveEntities() { + return activeEntities; + } + + public long getTotalUpdates() { + return totalUpdates; + } + + public List getModelProfiles() { + return modelProfiles; + } + + public long getModelCount() { + return modelCount; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ProfileRequest.java b/src/main/java/org/opensearch/ad/transport/ProfileRequest.java index f38b4399c..ea779e733 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ProfileRequest.java @@ -18,8 +18,8 @@ import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; /** * implements a request to obtain profiles about an AD detector diff --git a/src/main/java/org/opensearch/ad/transport/ProfileRequest.java-e b/src/main/java/org/opensearch/ad/transport/ProfileRequest.java-e new file mode 100644 index 000000000..ea779e733 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ProfileRequest.java-e @@ -0,0 +1,97 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; + +import org.opensearch.action.support.nodes.BaseNodesRequest; +import org.opensearch.ad.model.DetectorProfileName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +/** + * implements a request to obtain profiles about an AD detector + */ +public class ProfileRequest extends BaseNodesRequest { + + private Set profilesToBeRetrieved; + private String detectorId; + private boolean forMultiEntityDetector; + + public ProfileRequest(StreamInput in) throws IOException { + super(in); + int size = in.readVInt(); + profilesToBeRetrieved = new HashSet(); + if (size != 0) { + for (int i = 0; i < size; i++) { + profilesToBeRetrieved.add(in.readEnum(DetectorProfileName.class)); + } + } + detectorId = in.readString(); + forMultiEntityDetector = in.readBoolean(); + } + + /** + * Constructor + * + * @param detectorId detector's id + * @param profilesToBeRetrieved profiles to be retrieved + * @param forMultiEntityDetector whether the request is for a multi-entity detector + * @param nodes nodes of nodes' profiles to be retrieved + */ + public ProfileRequest( + String detectorId, + Set profilesToBeRetrieved, + boolean forMultiEntityDetector, + DiscoveryNode... nodes + ) { + super(nodes); + this.detectorId = detectorId; + this.profilesToBeRetrieved = profilesToBeRetrieved; + this.forMultiEntityDetector = forMultiEntityDetector; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeVInt(profilesToBeRetrieved.size()); + for (DetectorProfileName profile : profilesToBeRetrieved) { + out.writeEnum(profile); + } + out.writeString(detectorId); + out.writeBoolean(forMultiEntityDetector); + } + + public String getId() { + return detectorId; + } + + /** + * Get the set that tracks which profiles should be retrieved + * + * @return the set that contains the profile names marked for retrieval + */ + public Set getProfilesToBeRetrieved() { + return profilesToBeRetrieved; + } + + /** + * + * @return Whether this is about a multi-entity detector or not + */ + public boolean isForMultiEntityDetector() { + return forMultiEntityDetector; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ProfileResponse.java b/src/main/java/org/opensearch/ad/transport/ProfileResponse.java index eb204aa6f..11ba28163 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ProfileResponse.java @@ -24,8 +24,8 @@ import org.opensearch.ad.model.ModelProfile; import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.cluster.ClusterName; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/ProfileResponse.java-e b/src/main/java/org/opensearch/ad/transport/ProfileResponse.java-e new file mode 100644 index 000000000..11ba28163 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ProfileResponse.java-e @@ -0,0 +1,201 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions 
made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.nodes.BaseNodesResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.ModelProfile; +import org.opensearch.ad.model.ModelProfileOnNode; +import org.opensearch.cluster.ClusterName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentFragment; +import org.opensearch.core.xcontent.XContentBuilder; + +/** + * This class consists of the aggregated responses from the nodes + */ +public class ProfileResponse extends BaseNodesResponse implements ToXContentFragment { + private static final Logger LOG = LogManager.getLogger(ProfileResponse.class); + // filed name in toXContent + static final String COORDINATING_NODE = ADCommonName.COORDINATING_NODE; + static final String SHINGLE_SIZE = ADCommonName.SHINGLE_SIZE; + static final String TOTAL_SIZE = ADCommonName.TOTAL_SIZE_IN_BYTES; + static final String ACTIVE_ENTITY = ADCommonName.ACTIVE_ENTITIES; + static final String MODELS = ADCommonName.MODELS; + static final String TOTAL_UPDATES = ADCommonName.TOTAL_UPDATES; + static final String MODEL_COUNT = ADCommonName.MODEL_COUNT; + + // changed from ModelProfile to ModelProfileOnNode since Opensearch 1.1 + private ModelProfileOnNode[] modelProfile; + private int shingleSize; + private String coordinatingNode; + private long totalSizeInBytes; + private long activeEntities; + private long totalUpdates; + // added since 1.1 + private long modelCount; + + /** + * Constructor + * + * @param in StreamInput + * @throws IOException thrown when unable to read from stream + */ + public ProfileResponse(StreamInput in) throws IOException { + super(in); + int size = in.readVInt(); + modelProfile = new ModelProfileOnNode[size]; + for (int i = 0; i < size; i++) { + modelProfile[i] = new ModelProfileOnNode(in); + } + + shingleSize = in.readInt(); + coordinatingNode = in.readString(); + totalSizeInBytes = in.readVLong(); + activeEntities = in.readVLong(); + totalUpdates = in.readVLong(); + modelCount = in.readVLong(); + } + + /** + * Constructor + * + * @param clusterName name of cluster + * @param nodes List of ProfileNodeResponse from nodes + * @param failures List of failures from nodes + */ + public ProfileResponse(ClusterName clusterName, List nodes, List failures) { + super(clusterName, nodes, failures); + totalSizeInBytes = 0L; + activeEntities = 0L; + totalUpdates = 0L; + shingleSize = -1; + modelCount = 0; + List modelProfileList = new ArrayList<>(); + for (ProfileNodeResponse response : nodes) { + String curNodeId = response.getNode().getId(); + if (response.getShingleSize() >= 0) { + coordinatingNode = curNodeId; + shingleSize = response.getShingleSize(); + } + if (response.getModelSize() != null) { + for (Map.Entry entry : response.getModelSize().entrySet()) { + totalSizeInBytes += entry.getValue(); + } + } + if (response.getModelProfiles() != null && response.getModelProfiles().size() > 0) { + modelCount += response.getModelCount(); + for 
(ModelProfile profile : response.getModelProfiles()) { + modelProfileList.add(new ModelProfileOnNode(curNodeId, profile)); + } + } else if (response.getModelSize() != null && response.getModelSize().size() > 0) { + for (Map.Entry entry : response.getModelSize().entrySet()) { + // single-stream detectors have no entity info + modelProfileList.add(new ModelProfileOnNode(curNodeId, new ModelProfile(entry.getKey(), null, entry.getValue()))); + } + } + + if (response.getActiveEntities() > 0) { + activeEntities += response.getActiveEntities(); + } + if (response.getTotalUpdates() > totalUpdates) { + totalUpdates = response.getTotalUpdates(); + } + } + if (coordinatingNode == null) { + coordinatingNode = ""; + } + this.modelProfile = modelProfileList.toArray(new ModelProfileOnNode[0]); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeVInt(modelProfile.length); + + for (ModelProfileOnNode profile : modelProfile) { + profile.writeTo(out); + } + + out.writeInt(shingleSize); + out.writeString(coordinatingNode); + out.writeVLong(totalSizeInBytes); + out.writeVLong(activeEntities); + out.writeVLong(totalUpdates); + out.writeVLong(modelCount); + } + + @Override + public void writeNodesTo(StreamOutput out, List nodes) throws IOException { + out.writeList(nodes); + } + + @Override + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(ProfileNodeResponse::readProfiles); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(COORDINATING_NODE, coordinatingNode); + builder.field(SHINGLE_SIZE, shingleSize); + builder.field(TOTAL_SIZE, totalSizeInBytes); + builder.field(ACTIVE_ENTITY, activeEntities); + builder.field(TOTAL_UPDATES, totalUpdates); + if (modelCount > 0) { + builder.field(MODEL_COUNT, modelCount); + } + builder.startArray(MODELS); + for (ModelProfileOnNode profile : modelProfile) { + profile.toXContent(builder, params); + } + builder.endArray(); + return builder; + } + + public ModelProfileOnNode[] getModelProfile() { + return modelProfile; + } + + public int getShingleSize() { + return shingleSize; + } + + public long getActiveEntities() { + return activeEntities; + } + + public long getTotalUpdates() { + return totalUpdates; + } + + public String getCoordinatingNode() { + return coordinatingNode; + } + + public long getTotalSizeInBytes() { + return totalSizeInBytes; + } + + public long getModelCount() { + return modelCount; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java b/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java index 34f0eef87..e05251f2f 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java @@ -30,8 +30,8 @@ import org.opensearch.ad.model.ModelProfile; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; diff --git a/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java-e new file mode 100644 index 000000000..69c1c6f52 --- /dev/null +++ 
b/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java-e @@ -0,0 +1,155 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_MODEL_SIZE_PER_NODE; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.model.DetectorProfileName; +import org.opensearch.ad.model.ModelProfile; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +/** + * This class contains the logic to extract the stats from the nodes + */ +public class ProfileTransportAction extends TransportNodesAction { + private static final Logger LOG = LogManager.getLogger(ProfileTransportAction.class); + private ModelManager modelManager; + private FeatureManager featureManager; + private CacheProvider cacheProvider; + // the number of models to return. Defaults to 10. 
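+    // Backed by the dynamic cluster setting MAX_MODEL_SIZE_PER_NODE; the settings-update consumer registered in the constructor keeps this field current without a node restart.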
+    private volatile int numModelsToReturn;
+
+    /**
+     * Constructor
+     *
+     * @param threadPool ThreadPool to use
+     * @param clusterService ClusterService
+     * @param transportService TransportService
+     * @param actionFilters Action Filters
+     * @param modelManager model manager object
+     * @param featureManager feature manager object
+     * @param cacheProvider cache provider
+     * @param settings Node settings accessor
+     */
+    @Inject
+    public ProfileTransportAction(
+        ThreadPool threadPool,
+        ClusterService clusterService,
+        TransportService transportService,
+        ActionFilters actionFilters,
+        ModelManager modelManager,
+        FeatureManager featureManager,
+        CacheProvider cacheProvider,
+        Settings settings
+    ) {
+        super(
+            ProfileAction.NAME,
+            threadPool,
+            clusterService,
+            transportService,
+            actionFilters,
+            ProfileRequest::new,
+            ProfileNodeRequest::new,
+            ThreadPool.Names.MANAGEMENT,
+            ProfileNodeResponse.class
+        );
+        this.modelManager = modelManager;
+        this.featureManager = featureManager;
+        this.cacheProvider = cacheProvider;
+        this.numModelsToReturn = MAX_MODEL_SIZE_PER_NODE.get(settings);
+        clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_MODEL_SIZE_PER_NODE, it -> this.numModelsToReturn = it);
+    }
+
+    @Override
+    protected ProfileResponse newResponse(ProfileRequest request, List<ProfileNodeResponse> responses, List<FailedNodeException> failures) {
+        return new ProfileResponse(clusterService.getClusterName(), responses, failures);
+    }
+
+    @Override
+    protected ProfileNodeRequest newNodeRequest(ProfileRequest request) {
+        return new ProfileNodeRequest(request);
+    }
+
+    @Override
+    protected ProfileNodeResponse newNodeResponse(StreamInput in) throws IOException {
+        return new ProfileNodeResponse(in);
+    }
+
+    @Override
+    protected ProfileNodeResponse nodeOperation(ProfileNodeRequest request) {
+        String detectorId = request.getId();
+        Set<DetectorProfileName> profiles = request.getProfilesToBeRetrieved();
+        int shingleSize = -1;
+        long activeEntity = 0;
+        long totalUpdates = 0;
+        Map<String, Long> modelSize = null;
+        List<ModelProfile> modelProfiles = null;
+        int modelCount = 0;
+        if (request.isForMultiEntityDetector()) {
+            if (profiles.contains(DetectorProfileName.ACTIVE_ENTITIES)) {
+                activeEntity = cacheProvider.get().getActiveEntities(detectorId);
+            }
+            if (profiles.contains(DetectorProfileName.INIT_PROGRESS)) {
+                totalUpdates = cacheProvider.get().getTotalUpdates(detectorId);// get total updates
+            }
+            if (profiles.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES)) {
+                modelSize = cacheProvider.get().getModelSize(detectorId);
+            }
+            // need to provide entity info for HCAD
+            if (profiles.contains(DetectorProfileName.MODELS)) {
+                modelProfiles = cacheProvider.get().getAllModelProfile(detectorId);
+                modelCount = modelProfiles.size();
+                int limit = Math.min(numModelsToReturn, modelCount);
+                if (limit != modelCount) {
+                    LOG.info("model number limit reached");
+                    modelProfiles = modelProfiles.subList(0, limit);
+                }
+            }
+        } else {
+            if (profiles.contains(DetectorProfileName.COORDINATING_NODE) || profiles.contains(DetectorProfileName.SHINGLE_SIZE)) {
+                shingleSize = featureManager.getShingleSize(detectorId);
+            }
+
+            if (profiles.contains(DetectorProfileName.TOTAL_SIZE_IN_BYTES) || profiles.contains(DetectorProfileName.MODELS)) {
+                modelSize = modelManager.getModelSize(detectorId);
+            }
+        }
+
+        return new ProfileNodeResponse(
+            clusterService.localNode(),
+            modelSize,
+            shingleSize,
+            activeEntity,
+            totalUpdates,
+            modelProfiles,
+            modelCount
+        );
+    }
+}
diff --git a/src/main/java/org/opensearch/ad/transport/RCFPollingAction.java-e
b/src/main/java/org/opensearch/ad/transport/RCFPollingAction.java-e
new file mode 100644
index 000000000..147ff74cb
--- /dev/null
+++ b/src/main/java/org/opensearch/ad/transport/RCFPollingAction.java-e
@@ -0,0 +1,26 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.ad.transport;
+
+import org.opensearch.action.ActionType;
+import org.opensearch.ad.constant.CommonValue;
+
+public class RCFPollingAction extends ActionType<RCFPollingResponse> {
+    // Internal Action which is not used for public facing RestAPIs.
+    public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "rcfpolling";
+    public static final RCFPollingAction INSTANCE = new RCFPollingAction();
+
+    private RCFPollingAction() {
+        super(NAME, RCFPollingResponse::new);
+    }
+
+}
diff --git a/src/main/java/org/opensearch/ad/transport/RCFPollingRequest.java b/src/main/java/org/opensearch/ad/transport/RCFPollingRequest.java
index a0e2c1e49..fdf2055cf 100644
--- a/src/main/java/org/opensearch/ad/transport/RCFPollingRequest.java
+++ b/src/main/java/org/opensearch/ad/transport/RCFPollingRequest.java
@@ -19,9 +19,9 @@
 import org.opensearch.action.ActionRequestValidationException;
 import org.opensearch.ad.constant.ADCommonMessages;
 import org.opensearch.ad.constant.ADCommonName;
-import org.opensearch.common.io.stream.StreamInput;
-import org.opensearch.common.io.stream.StreamOutput;
 import org.opensearch.core.common.Strings;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
 import org.opensearch.core.xcontent.ToXContentObject;
 import org.opensearch.core.xcontent.XContentBuilder;
diff --git a/src/main/java/org/opensearch/ad/transport/RCFPollingRequest.java-e b/src/main/java/org/opensearch/ad/transport/RCFPollingRequest.java-e
new file mode 100644
index 000000000..95f429efa
--- /dev/null
+++ b/src/main/java/org/opensearch/ad/transport/RCFPollingRequest.java-e
@@ -0,0 +1,67 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +public class RCFPollingRequest extends ActionRequest implements ToXContentObject { + private String adID; + + public RCFPollingRequest(StreamInput in) throws IOException { + super(in); + adID = in.readString(); + } + + public RCFPollingRequest(String adID) { + super(); + this.adID = adID; + } + + public String getAdID() { + return adID; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(adID); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(adID)) { + validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + } + return validationException; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.endObject(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/RCFPollingResponse.java b/src/main/java/org/opensearch/ad/transport/RCFPollingResponse.java index 7c7ec9b2e..f3d7a43da 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFPollingResponse.java +++ b/src/main/java/org/opensearch/ad/transport/RCFPollingResponse.java @@ -14,8 +14,8 @@ import java.io.IOException; import org.opensearch.action.ActionResponse; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/RCFPollingResponse.java-e b/src/main/java/org/opensearch/ad/transport/RCFPollingResponse.java-e new file mode 100644 index 000000000..f3d7a43da --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/RCFPollingResponse.java-e @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +public class RCFPollingResponse extends ActionResponse implements ToXContentObject { + public static final String TOTAL_UPDATES_KEY = "totalUpdates"; + + private final long totalUpdates; + + public RCFPollingResponse(long totalUpdates) { + this.totalUpdates = totalUpdates; + } + + public RCFPollingResponse(StreamInput in) throws IOException { + super(in); + totalUpdates = in.readVLong(); + } + + public long getTotalUpdates() { + return totalUpdates; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(totalUpdates); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TOTAL_UPDATES_KEY, totalUpdates); + builder.endObject(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java b/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java index 5f2403b00..49e6f0153 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java @@ -27,8 +27,8 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.common.exception.TimeSeriesException; diff --git a/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java-e new file mode 100644 index 000000000..bd07eda77 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java-e @@ -0,0 +1,147 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.Locale; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.SingleStreamModelIdMapper; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.TransportResponseHandler; +import org.opensearch.transport.TransportService; + +/** + * Transport action to get total rcf updates from hosted models or checkpoint + * + */ +public class RCFPollingTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(RCFPollingTransportAction.class); + static final String NO_NODE_FOUND_MSG = "Cannot find model hosting node"; + static final String FAIL_TO_GET_RCF_UPDATE_MSG = "Cannot find hosted model or related checkpoint"; + + private final TransportService transportService; + private final ModelManager modelManager; + private final HashRing hashRing; + private final TransportRequestOptions option; + private final ClusterService clusterService; + + @Inject + public RCFPollingTransportAction( + ActionFilters actionFilters, + TransportService transportService, + Settings settings, + ModelManager modelManager, + HashRing hashRing, + ClusterService clusterService + ) { + super(RCFPollingAction.NAME, transportService, actionFilters, RCFPollingRequest::new); + this.transportService = transportService; + this.modelManager = modelManager; + this.hashRing = hashRing; + this.option = TransportRequestOptions + .builder() + .withType(TransportRequestOptions.Type.REG) + .withTimeout(AnomalyDetectorSettings.REQUEST_TIMEOUT.get(settings)) + .build(); + this.clusterService = clusterService; + } + + @Override + protected void doExecute(Task task, RCFPollingRequest request, ActionListener listener) { + + String adID = request.getAdID(); + + String rcfModelID = SingleStreamModelIdMapper.getRcfModelId(adID, 0); + + Optional rcfNode = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(rcfModelID); + if (!rcfNode.isPresent()) { + listener.onFailure(new TimeSeriesException(adID, NO_NODE_FOUND_MSG)); + return; + } + + String rcfNodeId = rcfNode.get().getId(); + + DiscoveryNode localNode = clusterService.localNode(); + + if (localNode.getId().equals(rcfNodeId)) { + modelManager + .getTotalUpdates( + rcfModelID, + adID, + ActionListener + .wrap( + totalUpdates -> listener.onResponse(new RCFPollingResponse(totalUpdates)), + e -> listener.onFailure(new TimeSeriesException(adID, FAIL_TO_GET_RCF_UPDATE_MSG, e)) + ) + ); + } else if (request.remoteAddress() == null) { + // redirect if request comes from local host. + // If a request comes from remote machine, it is already redirected. 
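+            // (remoteAddress() is null only for requests that originated on this node and were not received over the transport layer.)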
+            // One redirection should be enough.
+            // We don't want a potential infinite loop due to any bug and thus give up.
+            LOG.info("Sending RCF polling request to {} for model {}", rcfNodeId, rcfModelID);
+
+            try {
+                transportService
+                    .sendRequest(rcfNode.get(), RCFPollingAction.NAME, request, option, new TransportResponseHandler<RCFPollingResponse>() {
+
+                        @Override
+                        public RCFPollingResponse read(StreamInput in) throws IOException {
+                            return new RCFPollingResponse(in);
+                        }
+
+                        @Override
+                        public void handleResponse(RCFPollingResponse response) {
+                            listener.onResponse(response);
+                        }
+
+                        @Override
+                        public void handleException(TransportException exp) {
+                            listener.onFailure(exp);
+                        }
+
+                        @Override
+                        public String executor() {
+                            return ThreadPool.Names.SAME;
+                        }
+
+                    });
+            } catch (Exception e) {
+                LOG.error(String.format(Locale.ROOT, "Fail to poll RCF models for %s", adID), e);
+                listener.onFailure(new TimeSeriesException(adID, FAIL_TO_GET_RCF_UPDATE_MSG, e));
+            }
+
+        } else {
+            LOG.error("Fail to poll rcf for model {} due to an unexpected bug.", rcfModelID);
+            listener.onFailure(new TimeSeriesException(adID, NO_NODE_FOUND_MSG));
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/ad/transport/RCFResultAction.java-e b/src/main/java/org/opensearch/ad/transport/RCFResultAction.java-e
new file mode 100644
index 000000000..3480e880a
--- /dev/null
+++ b/src/main/java/org/opensearch/ad/transport/RCFResultAction.java-e
@@ -0,0 +1,26 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.ad.transport;
+
+import org.opensearch.action.ActionType;
+import org.opensearch.ad.constant.CommonValue;
+
+public class RCFResultAction extends ActionType<RCFResultResponse> {
+    // Internal Action which is not used for public facing RestAPIs.
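+    // Sent from the coordinating node to the node hosting a detector's RCF model to score a feature vector; see RCFResultTransportAction below.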
+ public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "rcf/result"; + public static final RCFResultAction INSTANCE = new RCFResultAction(); + + private RCFResultAction() { + super(NAME, RCFResultResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/RCFResultRequest.java b/src/main/java/org/opensearch/ad/transport/RCFResultRequest.java index d06f1eae8..b617704b8 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFResultRequest.java +++ b/src/main/java/org/opensearch/ad/transport/RCFResultRequest.java @@ -19,9 +19,9 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.constant.CommonName; diff --git a/src/main/java/org/opensearch/ad/transport/RCFResultRequest.java-e b/src/main/java/org/opensearch/ad/transport/RCFResultRequest.java-e new file mode 100644 index 000000000..70efd2b0d --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/RCFResultRequest.java-e @@ -0,0 +1,106 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; + +public class RCFResultRequest extends ActionRequest implements ToXContentObject { + private String adID; + private String modelID; + private double[] features; + + // Messages used for validation error + public static final String INVALID_FEATURE_MSG = "feature vector is empty"; + + public RCFResultRequest(StreamInput in) throws IOException { + super(in); + adID = in.readString(); + modelID = in.readString(); + int size = in.readVInt(); + features = new double[size]; + for (int i = 0; i < size; i++) { + features[i] = in.readDouble(); + } + } + + public RCFResultRequest(String adID, String modelID, double[] features) { + super(); + this.adID = adID; + this.modelID = modelID; + this.features = features; + } + + public double[] getFeatures() { + return features; + } + + public String getAdID() { + return adID; + } + + public String getModelID() { + return modelID; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(adID); + out.writeString(modelID); + out.writeVInt(features.length); + for (double feature : features) { + 
out.writeDouble(feature); + } + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (features == null || features.length == 0) { + validationException = addValidationError(RCFResultRequest.INVALID_FEATURE_MSG, validationException); + } + if (Strings.isEmpty(adID)) { + validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + } + if (Strings.isEmpty(modelID)) { + validationException = addValidationError(ADCommonMessages.MODEL_ID_MISSING_MSG, validationException); + } + return validationException; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.field(CommonName.MODEL_ID_FIELD, modelID); + builder.startArray(ADCommonName.FEATURE_JSON_KEY); + for (double feature : features) { + builder.value(feature); + } + builder.endArray(); + builder.endObject(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/RCFResultResponse.java b/src/main/java/org/opensearch/ad/transport/RCFResultResponse.java index 769ad5f4f..f6484b840 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFResultResponse.java +++ b/src/main/java/org/opensearch/ad/transport/RCFResultResponse.java @@ -16,8 +16,8 @@ import org.opensearch.Version; import org.opensearch.action.ActionResponse; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/RCFResultResponse.java-e b/src/main/java/org/opensearch/ad/transport/RCFResultResponse.java-e new file mode 100644 index 000000000..f6484b840 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/RCFResultResponse.java-e @@ -0,0 +1,222 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.Version; +import org.opensearch.action.ActionResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +public class RCFResultResponse extends ActionResponse implements ToXContentObject { + public static final String RCF_SCORE_JSON_KEY = "rcfScore"; + public static final String CONFIDENCE_JSON_KEY = "confidence"; + public static final String FOREST_SIZE_JSON_KEY = "forestSize"; + public static final String ATTRIBUTION_JSON_KEY = "attribution"; + public static final String TOTAL_UPDATES_JSON_KEY = "total_updates"; + public static final String RELATIVE_INDEX_FIELD_JSON_KEY = "relativeIndex"; + public static final String PAST_VALUES_FIELD_JSON_KEY = "pastValues"; + public static final String EXPECTED_VAL_LIST_FIELD_JSON_KEY = "expectedValuesList"; + public static final String LIKELIHOOD_FIELD_JSON_KEY = "likelihoodOfValues"; + public static final String THRESHOLD_FIELD_JSON_KEY = "threshold"; + + private Double rcfScore; + private Double confidence; + private Integer forestSize; + private double[] attribution; + private Long totalUpdates = 0L; + private Version remoteAdVersion; + private Double anomalyGrade; + private Integer relativeIndex; + private double[] pastValues; + private double[][] expectedValuesList; + private double[] likelihoodOfValues; + private Double threshold; + + public RCFResultResponse( + double rcfScore, + double confidence, + int forestSize, + double[] attribution, + long totalUpdates, + double grade, + Version remoteAdVersion, + Integer relativeIndex, + double[] pastValues, + double[][] expectedValuesList, + double[] likelihoodOfValues, + Double threshold + ) { + this.rcfScore = rcfScore; + this.confidence = confidence; + this.forestSize = forestSize; + this.attribution = attribution; + this.totalUpdates = totalUpdates; + this.anomalyGrade = grade; + this.remoteAdVersion = remoteAdVersion; + this.relativeIndex = relativeIndex; + this.pastValues = pastValues; + this.expectedValuesList = expectedValuesList; + this.likelihoodOfValues = likelihoodOfValues; + this.threshold = threshold; + } + + public RCFResultResponse(StreamInput in) throws IOException { + super(in); + this.rcfScore = in.readDouble(); + this.confidence = in.readDouble(); + this.forestSize = in.readVInt(); + this.attribution = in.readDoubleArray(); + if (in.available() > 0) { + this.totalUpdates = in.readLong(); + this.anomalyGrade = in.readDouble(); + this.relativeIndex = in.readOptionalInt(); + + if (in.readBoolean()) { + this.pastValues = in.readDoubleArray(); + } else { + this.pastValues = null; + } + + if (in.readBoolean()) { + int numberofExpectedVals = in.readVInt(); + this.expectedValuesList = new double[numberofExpectedVals][]; + for (int i = 0; i < numberofExpectedVals; i++) { + expectedValuesList[i] = in.readDoubleArray(); + } + } else { + this.expectedValuesList = null; + } + + if (in.readBoolean()) { + this.likelihoodOfValues = in.readDoubleArray(); + } else { + this.likelihoodOfValues = null; + } + + this.threshold = in.readOptionalDouble(); + } + } + + public Double getRCFScore() { + return rcfScore; + } + + public Double getConfidence() { + return confidence; + } + + public Integer getForestSize() { + return forestSize; + } + + /** + * Returns RCF score attribution. 
Can be null when anomaly grade is less than + * or equals to 0. + * + * @return RCF score attribution. + */ + public double[] getAttribution() { + return attribution; + } + + public Long getTotalUpdates() { + return totalUpdates; + } + + public Double getAnomalyGrade() { + return anomalyGrade; + } + + public Integer getRelativeIndex() { + return relativeIndex; + } + + public double[] getPastValues() { + return pastValues; + } + + public double[][] getExpectedValuesList() { + return expectedValuesList; + } + + public double[] getLikelihoodOfValues() { + return likelihoodOfValues; + } + + public Double getThreshold() { + return threshold; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(rcfScore); + out.writeDouble(confidence); + out.writeVInt(forestSize); + out.writeDoubleArray(attribution); + if (remoteAdVersion != null) { + out.writeLong(totalUpdates); + out.writeDouble(anomalyGrade); + out.writeOptionalInt(relativeIndex); + + if (pastValues != null) { + out.writeBoolean(true); + out.writeDoubleArray(pastValues); + } else { + out.writeBoolean(false); + } + + if (expectedValuesList != null) { + out.writeBoolean(true); + int numberofExpectedVals = expectedValuesList.length; + out.writeVInt(expectedValuesList.length); + for (int i = 0; i < numberofExpectedVals; i++) { + out.writeDoubleArray(expectedValuesList[i]); + } + } else { + out.writeBoolean(false); + } + + if (likelihoodOfValues != null) { + out.writeBoolean(true); + out.writeDoubleArray(likelihoodOfValues); + } else { + out.writeBoolean(false); + } + + out.writeOptionalDouble(threshold); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(RCF_SCORE_JSON_KEY, rcfScore); + builder.field(CONFIDENCE_JSON_KEY, confidence); + builder.field(FOREST_SIZE_JSON_KEY, forestSize); + builder.field(ATTRIBUTION_JSON_KEY, attribution); + builder.field(TOTAL_UPDATES_JSON_KEY, totalUpdates); + builder.field(ADCommonName.ANOMALY_GRADE_JSON_KEY, anomalyGrade); + builder.field(RELATIVE_INDEX_FIELD_JSON_KEY, relativeIndex); + builder.field(PAST_VALUES_FIELD_JSON_KEY, pastValues); + builder.field(EXPECTED_VAL_LIST_FIELD_JSON_KEY, expectedValuesList); + builder.field(LIKELIHOOD_FIELD_JSON_KEY, likelihoodOfValues); + builder.field(THRESHOLD_FIELD_JSON_KEY, threshold); + builder.endObject(); + return builder; + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/RCFResultTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/RCFResultTransportAction.java-e new file mode 100644 index 000000000..f9d63365c --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/RCFResultTransportAction.java-e @@ -0,0 +1,129 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.net.ConnectException; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.inject.Inject; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.transport.TransportService; + +public class RCFResultTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(RCFResultTransportAction.class); + private ModelManager manager; + private ADCircuitBreakerService adCircuitBreakerService; + private HashRing hashRing; + private ADStats adStats; + + @Inject + public RCFResultTransportAction( + ActionFilters actionFilters, + TransportService transportService, + ModelManager manager, + ADCircuitBreakerService adCircuitBreakerService, + HashRing hashRing, + ADStats adStats + ) { + super(RCFResultAction.NAME, transportService, actionFilters, RCFResultRequest::new); + this.manager = manager; + this.adCircuitBreakerService = adCircuitBreakerService; + this.hashRing = hashRing; + this.adStats = adStats; + } + + @Override + protected void doExecute(Task task, RCFResultRequest request, ActionListener listener) { + if (adCircuitBreakerService.isOpen()) { + listener.onFailure(new LimitExceededException(request.getAdID(), CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG)); + return; + } + Optional remoteNode = hashRing.getNodeByAddress(request.remoteAddress()); + if (!remoteNode.isPresent()) { + listener.onFailure(new ConnectException("Can't find remote node by address")); + return; + } + String remoteNodeId = remoteNode.get().getId(); + Version remoteAdVersion = hashRing.getAdVersion(remoteNodeId); + + try { + LOG.info("Serve rcf request for {}", request.getModelID()); + manager + .getTRcfResult( + request.getAdID(), + request.getModelID(), + request.getFeatures(), + ActionListener + .wrap( + result -> listener + .onResponse( + new RCFResultResponse( + result.getRcfScore(), + result.getConfidence(), + result.getForestSize(), + result.getRelevantAttribution(), + result.getTotalUpdates(), + result.getGrade(), + remoteAdVersion, + result.getRelativeIndex(), + result.getPastValues(), + result.getExpectedValuesList(), + result.getLikelihoodOfValues(), + result.getThreshold() + ) + ), + exception -> { + if (exception instanceof IllegalArgumentException) { + // fail to score likely due to model corruption. Re-cold start to recover. 
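+                                // Clearing the model below removes the corrupted state so the next anomaly result request triggers a fresh cold start.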
+ LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", request.getAdID()), exception); + adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).increment(); + manager + .clear( + request.getAdID(), + ActionListener + .wrap( + r -> LOG.info("Deleted model for [{}] with response [{}] ", request.getAdID(), r), + ex -> LOG.error("Fail to delete model for " + request.getAdID(), ex) + ) + ); + listener.onFailure(exception); + } else { + LOG.warn(exception); + listener.onFailure(exception); + } + } + ) + ); + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/SearchADTasksAction.java-e b/src/main/java/org/opensearch/ad/transport/SearchADTasksAction.java-e new file mode 100644 index 000000000..50b597e62 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/SearchADTasksAction.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.CommonValue; + +public class SearchADTasksAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "tasks/search"; + public static final SearchADTasksAction INSTANCE = new SearchADTasksAction(); + + private SearchADTasksAction() { + super(NAME, SearchResponse::new); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/SearchADTasksTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/SearchADTasksTransportAction.java-e new file mode 100644 index 000000000..15ebdb26a --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/SearchADTasksTransportAction.java-e @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.transport.handler.ADSearchHandler; +import org.opensearch.common.inject.Inject; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class SearchADTasksTransportAction extends HandledTransportAction { + private ADSearchHandler searchHandler; + + @Inject + public SearchADTasksTransportAction(TransportService transportService, ActionFilters actionFilters, ADSearchHandler searchHandler) { + super(SearchADTasksAction.NAME, transportService, actionFilters, SearchRequest::new); + this.searchHandler = searchHandler; + } + + @Override + protected void doExecute(Task task, SearchRequest request, ActionListener listener) { + searchHandler.search(request, listener); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorAction.java-e b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorAction.java-e new file mode 100644 index 000000000..c15ece9ab --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorAction.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.CommonValue; + +public class SearchAnomalyDetectorAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/search"; + public static final SearchAnomalyDetectorAction INSTANCE = new SearchAnomalyDetectorAction(); + + private SearchAnomalyDetectorAction() { + super(NAME, SearchResponse::new); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoAction.java-e b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoAction.java-e new file mode 100644 index 000000000..3f4f7c2fc --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoAction.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class SearchAnomalyDetectorInfoAction extends ActionType { + // External Action which used for public facing RestAPIs. 
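+    // The transport action below either counts detector configs or checks whether a detector name already exists, depending on the request path.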
+ public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/info"; + public static final SearchAnomalyDetectorInfoAction INSTANCE = new SearchAnomalyDetectorInfoAction(); + + private SearchAnomalyDetectorInfoAction() { + super(NAME, SearchAnomalyDetectorInfoResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoRequest.java b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoRequest.java index bd4ed7993..8289619c1 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoRequest.java +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoRequest.java @@ -15,8 +15,8 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; public class SearchAnomalyDetectorInfoRequest extends ActionRequest { diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoRequest.java-e b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoRequest.java-e new file mode 100644 index 000000000..8289619c1 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoRequest.java-e @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class SearchAnomalyDetectorInfoRequest extends ActionRequest { + + private String name; + private String rawPath; + + public SearchAnomalyDetectorInfoRequest(StreamInput in) throws IOException { + super(in); + name = in.readOptionalString(); + rawPath = in.readString(); + } + + public SearchAnomalyDetectorInfoRequest(String name, String rawPath) throws IOException { + super(); + this.name = name; + this.rawPath = rawPath; + } + + public String getName() { + return name; + } + + public String getRawPath() { + return rawPath; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalString(name); + out.writeString(rawPath); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoResponse.java b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoResponse.java index 90139ed12..3fdec11c6 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoResponse.java +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoResponse.java @@ -14,8 +14,8 @@ import java.io.IOException; import org.opensearch.action.ActionResponse; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; 
import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.util.RestHandlerUtils; diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoResponse.java-e b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoResponse.java-e new file mode 100644 index 000000000..3fdec11c6 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoResponse.java-e @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.util.RestHandlerUtils; + +public class SearchAnomalyDetectorInfoResponse extends ActionResponse implements ToXContentObject { + private long count; + private boolean nameExists; + + public SearchAnomalyDetectorInfoResponse(StreamInput in) throws IOException { + super(in); + count = in.readLong(); + nameExists = in.readBoolean(); + } + + public SearchAnomalyDetectorInfoResponse(long count, boolean nameExists) { + this.count = count; + this.nameExists = nameExists; + } + + public long getCount() { + return count; + } + + public boolean isNameExists() { + return nameExists; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeLong(count); + out.writeBoolean(nameExists); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(RestHandlerUtils.COUNT, count); + builder.field(RestHandlerUtils.MATCH, nameExists); + builder.endObject(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoTransportAction.java-e new file mode 100644 index 000000000..1570aa359 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoTransportAction.java-e @@ -0,0 +1,126 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_GET_DETECTOR_INFO; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +public class SearchAnomalyDetectorInfoTransportAction extends + HandledTransportAction { + private static final Logger LOG = LogManager.getLogger(SearchAnomalyDetectorInfoTransportAction.class); + private final Client client; + private final ClusterService clusterService; + + @Inject + public SearchAnomalyDetectorInfoTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService + ) { + super(SearchAnomalyDetectorInfoAction.NAME, transportService, actionFilters, SearchAnomalyDetectorInfoRequest::new); + this.client = client; + this.clusterService = clusterService; + } + + @Override + protected void doExecute( + Task task, + SearchAnomalyDetectorInfoRequest request, + ActionListener actionListener + ) { + String name = request.getName(); + String rawPath = request.getRawPath(); + ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_GET_DETECTOR_INFO); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + SearchRequest searchRequest = new SearchRequest().indices(CommonName.CONFIG_INDEX); + if (rawPath.endsWith(RestHandlerUtils.COUNT)) { + // Count detectors + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchRequest.source(searchSourceBuilder); + client.search(searchRequest, new ActionListener() { + + @Override + public void onResponse(SearchResponse searchResponse) { + SearchAnomalyDetectorInfoResponse response = new SearchAnomalyDetectorInfoResponse( + searchResponse.getHits().getTotalHits().value, + false + ); + listener.onResponse(response); + } + + @Override + public void onFailure(Exception e) { + if (e.getClass() == IndexNotFoundException.class) { + // Anomaly Detectors index does not exist + // Could be that user is creating first detector + SearchAnomalyDetectorInfoResponse response = new SearchAnomalyDetectorInfoResponse(0, false); + listener.onResponse(response); + } else { + listener.onFailure(e); + } + } + }); + } else { + // Match name with existing detectors + TermsQueryBuilder query = QueryBuilders.termsQuery("name.keyword", name); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); + searchRequest.source(searchSourceBuilder); + client.search(searchRequest, new ActionListener() { + + @Override + public void 
onResponse(SearchResponse searchResponse) { + boolean nameExists = false; + nameExists = searchResponse.getHits().getTotalHits().value > 0; + SearchAnomalyDetectorInfoResponse response = new SearchAnomalyDetectorInfoResponse(0, nameExists); + listener.onResponse(response); + } + + @Override + public void onFailure(Exception e) { + if (e.getClass() == IndexNotFoundException.class) { + // Anomaly Detectors index does not exist + // Could be that user is creating first detector + SearchAnomalyDetectorInfoResponse response = new SearchAnomalyDetectorInfoResponse(0, false); + listener.onResponse(response); + } else { + listener.onFailure(e); + } + } + }); + } + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorTransportAction.java-e new file mode 100644 index 000000000..160662cf9 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorTransportAction.java-e @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.transport.handler.ADSearchHandler; +import org.opensearch.common.inject.Inject; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class SearchAnomalyDetectorTransportAction extends HandledTransportAction { + private ADSearchHandler searchHandler; + + @Inject + public SearchAnomalyDetectorTransportAction( + TransportService transportService, + ActionFilters actionFilters, + ADSearchHandler searchHandler + ) { + super(SearchAnomalyDetectorAction.NAME, transportService, actionFilters, SearchRequest::new); + this.searchHandler = searchHandler; + } + + @Override + protected void doExecute(Task task, SearchRequest request, ActionListener listener) { + searchHandler.search(request, listener); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultAction.java-e b/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultAction.java-e new file mode 100644 index 000000000..7e0178393 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultAction.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.CommonValue; + +public class SearchAnomalyResultAction extends ActionType { + // External Action which used for public facing RestAPIs. 
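+    // Anomaly results may live in the default result index pattern or in detector-specific custom result indices; the transport action below works out which of those the caller can actually search.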
+ public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "result/search"; + public static final SearchAnomalyResultAction INSTANCE = new SearchAnomalyResultAction(); + + private SearchAnomalyResultAction() { + super(NAME, SearchResponse::new); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultTransportAction.java-e new file mode 100644 index 000000000..1043754c2 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultTransportAction.java-e @@ -0,0 +1,285 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonName.CUSTOM_RESULT_INDEX_PREFIX; +import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_DETECTOR_UPPER_LIMIT; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.MultiSearchRequest; +import org.opensearch.action.search.MultiSearchResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.transport.handler.ADSearchHandler; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.terms.StringTerms; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.transport.TransportService; + +import com.google.common.annotations.VisibleForTesting; + +public class SearchAnomalyResultTransportAction extends HandledTransportAction { + public static final String RESULT_INDEX_AGG_NAME = "result_index"; + + private final Logger logger = LogManager.getLogger(SearchAnomalyResultTransportAction.class); + private ADSearchHandler searchHandler; + private final ClusterService clusterService; + private final IndexNameExpressionResolver indexNameExpressionResolver; + private final Client client; + + @Inject + public SearchAnomalyResultTransportAction( + TransportService transportService, + ActionFilters actionFilters, + ADSearchHandler searchHandler, + ClusterService clusterService, + IndexNameExpressionResolver indexNameExpressionResolver, + Client client + ) { + 
super(SearchAnomalyResultAction.NAME, transportService, actionFilters, SearchRequest::new); + this.searchHandler = searchHandler; + this.clusterService = clusterService; + this.indexNameExpressionResolver = indexNameExpressionResolver; + this.client = client; + } + + @VisibleForTesting + boolean validateIndexAndReturnOnlyQueryCustomResult(String[] indices) { + if (indices == null || indices.length == 0) { + throw new IllegalArgumentException("No indices set in search request"); + } + // Set query indices as default result indices, will check custom result indices permission and add + // custom indices which user has search permission later. + + boolean onlyQueryCustomResultIndex = true; + for (String indexName : indices) { + // If only query custom result index, don't need to set ALL_AD_RESULTS_INDEX_PATTERN in search request + if (ALL_AD_RESULTS_INDEX_PATTERN.equals(indexName)) { + onlyQueryCustomResultIndex = false; + } + } + return onlyQueryCustomResultIndex; + } + + @VisibleForTesting + void calculateCustomResultIndices(Set customResultIndices, String[] indices) { + String[] concreteIndices = indexNameExpressionResolver + .concreteIndexNames(clusterService.state(), IndicesOptions.lenientExpandOpen(), indices); + // If concreteIndices is null or empty, don't throw exception. Detector list page will search both + // default and custom result indices to get anomaly of last 24 hours. If throw exception, detector + // list page will throw error and won't show any detector. + // If a cluster has no custom result indices, and some new non-custom-result-detector that hasn't + // finished one interval (where no default result index exists), then no result indices found. We + // will still search ".opendistro-anomaly-results*" (even these default indices don't exist) to + // return an empty SearchResponse. This search looks unnecessary, but this can make sure the + // detector list page show all detectors correctly. The other solution is to catch errors from + // frontend when search anomaly results to make sure frontend won't crash. 
Check this Github issue: + // https://github.com/opensearch-project/anomaly-detection-dashboards-plugin/issues/154 + if (concreteIndices != null) { + for (String index : concreteIndices) { + if (index.startsWith(CUSTOM_RESULT_INDEX_PREFIX)) { + customResultIndices.add(index); + } + } + } + } + + @VisibleForTesting + SearchRequest createSingleSearchRequest() { + // Search both custom AD result index and default result index + SearchSourceBuilder searchResultIndexBuilder = new SearchSourceBuilder(); + AggregationBuilder aggregation = new TermsAggregationBuilder(RESULT_INDEX_AGG_NAME) + .field(AnomalyDetector.RESULT_INDEX_FIELD) + .size(MAX_DETECTOR_UPPER_LIMIT); + searchResultIndexBuilder.aggregation(aggregation).size(0); + return new SearchRequest(CommonName.CONFIG_INDEX).source(searchResultIndexBuilder); + } + + @VisibleForTesting + void processSingleSearchResponse( + SearchResponse allResultIndicesResponse, + SearchRequest request, + ActionListener listener, + Set customResultIndices, + List targetIndices + ) { + Aggregations aggregations = allResultIndicesResponse.getAggregations(); + StringTerms resultIndicesAgg = aggregations.get(RESULT_INDEX_AGG_NAME); + List buckets = resultIndicesAgg.getBuckets(); + Set resultIndicesOfDetector = new HashSet<>(); + if (buckets == null) { + searchHandler.search(request, listener); + return; + } + buckets.stream().forEach(b -> resultIndicesOfDetector.add(b.getKeyAsString())); + for (String index : customResultIndices) { + if (resultIndicesOfDetector.contains(index)) { + targetIndices.add(index); + } + } + if (targetIndices.size() == 0) { + // No custom result indices used by detectors, just search default result index + searchHandler.search(request, listener); + return; + } + } + + @VisibleForTesting + MultiSearchRequest createMultiSearchRequest(List targetIndices) { + MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); + for (String index : targetIndices) { + multiSearchRequest.add(new SearchRequest(index).source(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(0))); + } + return multiSearchRequest; + } + + @VisibleForTesting + void multiSearch( + List targetIndices, + SearchRequest request, + ActionListener listener, + boolean finalOnlyQueryCustomResultIndex, + ThreadContext.StoredContext context + ) { + if (targetIndices.size() == 0) { + // no need to make multi search + return; + } + MultiSearchRequest multiSearchRequest = createMultiSearchRequest(targetIndices); + List readableIndices = new ArrayList<>(); + if (!finalOnlyQueryCustomResultIndex) { + readableIndices.add(ALL_AD_RESULTS_INDEX_PATTERN); + } + + context.restore(); + // Send multiple search to check which index a user has permission to read. If search all indices directly, + // search request will throw exception if user has no permission to search any index. 
+ client + .multiSearch( + multiSearchRequest, + ActionListener + .wrap( + multiSearchResponse -> { + processMultiSearchResponse(multiSearchResponse, targetIndices, readableIndices, request, listener); + }, + multiSearchException -> { + logger.error("Failed to search custom AD result indices", multiSearchException); + listener.onFailure(multiSearchException); + } + ) + ); + } + + @VisibleForTesting + void processMultiSearchResponse( + MultiSearchResponse multiSearchResponse, + List targetIndices, + List readableIndices, + SearchRequest request, + ActionListener listener + ) { + MultiSearchResponse.Item[] responses = multiSearchResponse.getResponses(); + for (int i = 0; i < responses.length; i++) { + MultiSearchResponse.Item item = responses[i]; + String indexName = targetIndices.get(i); + if (item.getFailure() == null) { + readableIndices.add(indexName); + } + } + if (readableIndices.size() == 0) { + listener.onFailure(new IllegalArgumentException("No readable custom result indices found")); + return; + } + request.indices(readableIndices.toArray(new String[0])); + searchHandler.search(request, listener); + } + + @VisibleForTesting + void searchADResultIndex( + SearchRequest request, + ActionListener listener, + boolean onlyQueryCustomResultIndex, + Set customResultIndices + ) { + SearchRequest searchResultIndex = createSingleSearchRequest(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + // Search result indices of all detectors. User may create index with same prefix of custom result index + // which not used for AD, so we should avoid searching extra indices which not used by anomaly detectors. + // Variable used in lambda expression should be final or effectively final, so copy to a final boolean and + // use the final boolean in lambda below. + boolean finalOnlyQueryCustomResultIndex = onlyQueryCustomResultIndex; + client.search(searchResultIndex, ActionListener.wrap(allResultIndicesResponse -> { + List targetIndices = new ArrayList<>(); + processSingleSearchResponse(allResultIndicesResponse, request, listener, customResultIndices, targetIndices); + multiSearch(targetIndices, request, listener, finalOnlyQueryCustomResultIndex, context); + }, e -> { + logger.error("Failed to search result indices for all detectors", e); + listener.onFailure(e); + })); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + @Override + protected void doExecute(Task task, SearchRequest request, ActionListener listener) { + boolean onlyQueryCustomResultIndex; + Set customResultIndices = new HashSet<>(); + + try { + onlyQueryCustomResultIndex = validateIndexAndReturnOnlyQueryCustomResult(request.indices()); + calculateCustomResultIndices(customResultIndices, request.indices()); + // If user need to query custom result index only, and that custom result index deleted. Then + // we should not search anymore. Just throw exception here. 
+ if (onlyQueryCustomResultIndex && customResultIndices.size() == 0) { + throw new IllegalArgumentException("No custom result indices found"); + } + } catch (IllegalArgumentException exception) { + listener.onFailure(exception); + return; + } + + if (customResultIndices.size() == 0) { + // onlyQueryCustomResultIndex is false in this branch + // Search only default result index + request.indices(ALL_AD_RESULTS_INDEX_PATTERN); + searchHandler.search(request, listener); + return; + } + + searchADResultIndex(request, listener, onlyQueryCustomResultIndex, customResultIndices); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultAction.java-e b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultAction.java-e new file mode 100644 index 000000000..ee89c4179 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultAction.java-e @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class SearchTopAnomalyResultAction extends ActionType { + // External Action which used for public facing RestAPIs. + public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "result/topAnomalies"; + public static final SearchTopAnomalyResultAction INSTANCE = new SearchTopAnomalyResultAction(); + + private SearchTopAnomalyResultAction() { + super(NAME, SearchTopAnomalyResultResponse::new); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultRequest.java b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultRequest.java index 509f65ebb..8ae80077e 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultRequest.java +++ b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultRequest.java @@ -12,7 +12,7 @@ package org.opensearch.ad.transport; import static org.opensearch.action.ValidateActions.addValidationError; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.time.Instant; @@ -20,8 +20,8 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.timeseries.util.ParseUtils; diff --git a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultRequest.java-e b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultRequest.java-e new file mode 100644 index 000000000..8ae80077e --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultRequest.java-e @@ -0,0 +1,199 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.time.Instant; +import java.util.List; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.util.ParseUtils; + +/** + * Request for getting the top anomaly results for HC detectors. + *

+ * size, category field, and order are optional, and will be set to default values if left blank + */ +public class SearchTopAnomalyResultRequest extends ActionRequest { + + private static final String TASK_ID_FIELD = "task_id"; + private static final String SIZE_FIELD = "size"; + private static final String CATEGORY_FIELD_FIELD = "category_field"; + private static final String ORDER_FIELD = "order"; + private static final String START_TIME_FIELD = "start_time_ms"; + private static final String END_TIME_FIELD = "end_time_ms"; + private String detectorId; + private String taskId; + private boolean historical; + private Integer size; + private List categoryFields; + private String order; + private Instant startTime; + private Instant endTime; + + public SearchTopAnomalyResultRequest(StreamInput in) throws IOException { + super(in); + detectorId = in.readOptionalString(); + taskId = in.readOptionalString(); + historical = in.readBoolean(); + size = in.readOptionalInt(); + categoryFields = in.readOptionalStringList(); + order = in.readOptionalString(); + startTime = in.readInstant(); + endTime = in.readInstant(); + } + + public SearchTopAnomalyResultRequest( + String detectorId, + String taskId, + boolean historical, + Integer size, + List categoryFields, + String order, + Instant startTime, + Instant endTime + ) { + super(); + this.detectorId = detectorId; + this.taskId = taskId; + this.historical = historical; + this.size = size; + this.categoryFields = categoryFields; + this.order = order; + this.startTime = startTime; + this.endTime = endTime; + } + + public String getId() { + return detectorId; + } + + public String getTaskId() { + return taskId; + } + + public boolean getHistorical() { + return historical; + } + + public Integer getSize() { + return size; + } + + public List getCategoryFields() { + return categoryFields; + } + + public String getOrder() { + return order; + } + + public Instant getStartTime() { + return startTime; + } + + public Instant getEndTime() { + return endTime; + } + + public void setTaskId(String taskId) { + this.taskId = taskId; + } + + public void setSize(Integer size) { + this.size = size; + } + + public void setCategoryFields(List categoryFields) { + this.categoryFields = categoryFields; + } + + public void setOrder(String order) { + this.order = order; + } + + @SuppressWarnings("unchecked") + public static SearchTopAnomalyResultRequest parse(XContentParser parser, String detectorId, boolean historical) throws IOException { + String taskId = null; + Integer size = null; + List categoryFields = null; + String order = null; + Instant startTime = null; + Instant endTime = null; + + // "detectorId" and "historical" params come from the original API path, not in the request body + // and therefore don't need to be parsed + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case TASK_ID_FIELD: + taskId = parser.text(); + break; + case SIZE_FIELD: + size = parser.intValue(); + break; + case CATEGORY_FIELD_FIELD: + categoryFields = parser.list(); + break; + case ORDER_FIELD: + order = parser.text(); + break; + case START_TIME_FIELD: + startTime = ParseUtils.toInstant(parser); + break; + case END_TIME_FIELD: + endTime = ParseUtils.toInstant(parser); + break; + default: + break; + } + } + + // Cast category field Object list to String list + List convertedCategoryFields = 
(List) (List) (categoryFields); + return new SearchTopAnomalyResultRequest(detectorId, taskId, historical, size, convertedCategoryFields, order, startTime, endTime); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalString(detectorId); + out.writeOptionalString(taskId); + out.writeBoolean(historical); + out.writeOptionalInt(size); + out.writeOptionalStringCollection(categoryFields); + out.writeOptionalString(order); + out.writeInstant(startTime); + out.writeInstant(endTime); + } + + @Override + public ActionRequestValidationException validate() { + if (startTime == null || endTime == null) { + return addValidationError("Must set both start time and end time with epoch of milliseconds", null); + } + if (!startTime.isBefore(endTime)) { + return addValidationError("Start time should be before end time", null); + } + return null; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultResponse.java b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultResponse.java index eb432643c..2a33f57fa 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultResponse.java +++ b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultResponse.java @@ -16,8 +16,8 @@ import org.opensearch.action.ActionResponse; import org.opensearch.ad.model.AnomalyResultBucket; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultResponse.java-e b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultResponse.java-e new file mode 100644 index 000000000..2a33f57fa --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultResponse.java-e @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.action.ActionResponse; +import org.opensearch.ad.model.AnomalyResultBucket; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +/** + * Response for getting the top anomaly results for HC detectors + */ +public class SearchTopAnomalyResultResponse extends ActionResponse implements ToXContentObject { + private List anomalyResultBuckets; + + public SearchTopAnomalyResultResponse(StreamInput in) throws IOException { + super(in); + anomalyResultBuckets = in.readList(AnomalyResultBucket::new); + } + + public SearchTopAnomalyResultResponse(List anomalyResultBuckets) { + this.anomalyResultBuckets = anomalyResultBuckets; + } + + public List getAnomalyResultBuckets() { + return anomalyResultBuckets; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeList(anomalyResultBuckets); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject().field(AnomalyResultBucket.BUCKETS_FIELD, anomalyResultBuckets).endObject(); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportAction.java-e new file mode 100644 index 000000000..86ad7941a --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportAction.java-e @@ -0,0 +1,579 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.TOP_ANOMALY_RESULT_TIMEOUT_IN_MILLIS; + +import java.time.Clock; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.model.AnomalyResultBucket; +import org.opensearch.ad.transport.handler.ADSearchHandler; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.common.Strings; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptType; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.PipelineAggregatorBuilders; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; +import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; +import org.opensearch.search.aggregations.pipeline.BucketSortPipelineAggregationBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableMap; + +/** + * Transport action to fetch top anomaly results for some HC detector. Generates a + * query based on user input to fetch aggregated entity results. 
+ */ + +// Example of a generated query aggregating over the "Carrier" category field, and sorting on max anomaly grade, using +// a historical task ID: +// +// { +// "query": { +// "bool": { +// "filter": [ +// { +// "range": { +// "data_end_time": { +// "from": "2021-09-10T07:00:00.000Z", +// "to": "2021-09-30T07:00:00.000Z", +// "include_lower": true, +// "include_upper": true, +// "boost": 1.0 +// } +// } +// }, +// { +// "range": { +// "anomaly_grade": { +// "from": 0, +// "include_lower": false, +// "include_upper": true, +// "boost": 1.0 +// } +// } +// }, +// { +// "term": { +// "task_id": { +// "value": "2AwACXwBM-RcgLq7Za87", +// "boost": 1.0 +// } +// } +// } +// ], +// "adjust_pure_negative": true, +// "boost": 1.0 +// } +// }, +// "aggregations": { +// "multi_buckets": { +// "composite": { +// "size": 100, +// "sources": [ +// { +// "Carrier": { +// "terms": { +// "script": { +// "source": """ +// String value = null; +// if (params == null || params._source == null || params._source.entity == null) { +// return ""; +// } +// for (item in params._source.entity) { +// if (item["name"] == params["categoryField"]) { +// value = item['value']; +// break; +// } +// } +// return value; +// """, +// "lang": "painless", +// "params": { +// "categoryField": "Carrier" +// } +// }, +// "missing_bucket": false, +// "order": "asc" +// } +// } +// } +// ] +// }, +// "aggregations": { +// "max_anomaly_grade": { +// "max": { +// "field": "anomaly_grade" +// } +// }, +// "bucket_sort": { +// "bucket_sort": { +// "sort": [ +// { +// "max_anomaly_grade": { +// "order": "desc" +// } +// } +// ], +// "from": 0, +// "gap_policy": "SKIP" +// } +// } +// } +// } +// } +// } + +public class SearchTopAnomalyResultTransportAction extends + HandledTransportAction { + private ADSearchHandler searchHandler; + // Number of buckets to return per page + private static final int PAGE_SIZE = 1000; + private static final OrderType DEFAULT_ORDER_TYPE = OrderType.SEVERITY; + private static final int DEFAULT_SIZE = 10; + private static final int MAX_SIZE = 1000; + private static final String defaultIndex = ALL_AD_RESULTS_INDEX_PATTERN; + private static final String COUNT_FIELD = "_count"; + private static final String BUCKET_SORT_FIELD = "bucket_sort"; + public static final String MULTI_BUCKETS_FIELD = "multi_buckets"; + private static final Logger logger = LogManager.getLogger(SearchTopAnomalyResultTransportAction.class); + private final Client client; + private Clock clock; + + public enum OrderType { + SEVERITY("severity"), + OCCURRENCE("occurrence"); + + private String name; + + OrderType(String name) { + this.name = name; + } + + public String getName() { + return name; + } + } + + @Inject + public SearchTopAnomalyResultTransportAction( + TransportService transportService, + ActionFilters actionFilters, + ADSearchHandler searchHandler, + Client client + ) { + super(SearchTopAnomalyResultAction.NAME, transportService, actionFilters, SearchTopAnomalyResultRequest::new); + this.searchHandler = searchHandler; + this.client = client; + this.clock = Clock.systemUTC(); + } + + @Override + protected void doExecute(Task task, SearchTopAnomalyResultRequest request, ActionListener listener) { + + GetAnomalyDetectorRequest getAdRequest = new GetAnomalyDetectorRequest( + request.getId(), + // The default version value used in org.opensearch.rest.action.RestActions.parseVersion() + -3L, + false, + true, + "", + "", + false, + null + ); + client.execute(GetAnomalyDetectorAction.INSTANCE, getAdRequest, 
ActionListener.wrap(getAdResponse -> { + // Make sure detector exists + if (getAdResponse.getDetector() == null) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "No anomaly detector found with ID %s", request.getId())); + } + + // Make sure detector is HC + List categoryFieldsFromResponse = getAdResponse.getDetector().getCategoryFields(); + if (categoryFieldsFromResponse == null || categoryFieldsFromResponse.isEmpty()) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "No category fields found for detector ID %s", request.getId()) + ); + } + + // Validating the category fields. Setting the list to be all category fields, + // unless otherwise specified + if (request.getCategoryFields() == null || request.getCategoryFields().isEmpty()) { + request.setCategoryFields(categoryFieldsFromResponse); + } else { + for (String categoryField : request.getCategoryFields()) { + if (!categoryFieldsFromResponse.contains(categoryField)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Category field %s doesn't exist for detector ID %s", categoryField, request.getId()) + ); + } + } + } + + // Validating historical tasks if historical is true. Setting the ID to the latest historical task's + // ID, unless otherwise specified + if (request.getHistorical() == true) { + ADTask historicalTask = getAdResponse.getHistoricalAdTask(); + if (historicalTask == null) { + throw new ResourceNotFoundException( + String.format(Locale.ROOT, "No historical tasks found for detector ID %s", request.getId()) + ); + } + if (Strings.isNullOrEmpty(request.getTaskId())) { + request.setTaskId(historicalTask.getTaskId()); + } + } + + // Validating the order. If nothing passed use default + OrderType orderType; + String orderString = request.getOrder(); + if (Strings.isNullOrEmpty(orderString)) { + orderType = DEFAULT_ORDER_TYPE; + } else { + if (orderString.equals(OrderType.SEVERITY.getName())) { + orderType = OrderType.SEVERITY; + } else if (orderString.equals(OrderType.OCCURRENCE.getName())) { + orderType = OrderType.OCCURRENCE; + } else { + // No valid order type was passed, throw an error + throw new IllegalArgumentException(String.format(Locale.ROOT, "Ordering by %s is not a valid option", orderString)); + } + } + request.setOrder(orderType.getName()); + + // Validating the size. If nothing passed use default + if (request.getSize() == null) { + request.setSize(DEFAULT_SIZE); + } else if (request.getSize() > MAX_SIZE) { + throw new IllegalArgumentException("Size cannot exceed " + MAX_SIZE); + } else if (request.getSize() <= 0) { + throw new IllegalArgumentException("Size must be a positive integer"); + } + + // Generating the search request which will contain the generated query + SearchRequest searchRequest = generateSearchRequest(request); + + // Adding search over any custom result indices + String rawCustomResultIndex = getAdResponse.getDetector().getCustomResultIndex(); + String customResultIndex = rawCustomResultIndex == null ? null : rawCustomResultIndex.trim(); + if (!Strings.isNullOrEmpty(customResultIndex)) { + searchRequest.indices(defaultIndex, customResultIndex); + } + + // Utilizing the existing search() from SearchHandler to handle security permissions. Both user role + // and backend role filtering is handled in there, and any error will be propagated up and + // returned as a failure in this Listener. + // This same method is used for security handling for the search results action. 
Since this action + // is doing fundamentally the same thing, we can reuse the security logic here. + searchHandler + .search( + searchRequest, + new TopAnomalyResultListener( + listener, + searchRequest.source(), + clock.millis() + TOP_ANOMALY_RESULT_TIMEOUT_IN_MILLIS, + request.getSize(), + orderType, + customResultIndex + ) + ); + + }, exception -> { + logger.error("Failed to get top anomaly results", exception); + listener.onFailure(exception); + })); + + } + + /** + * ActionListener class to handle bucketed search results in a paginated fashion. + * Note that the bucket_sort aggregation is a pipeline aggregation, and is executed + * after all non-pipeline aggregations (including the composite bucket aggregation). + * Because of this, the sorting is only done locally based on the buckets + * in the current page. To get around this issue, we use a max + * heap and add all results to the heap until there are no more result buckets, + * to get the globally sorted set of result buckets. + */ + class TopAnomalyResultListener implements ActionListener { + private ActionListener listener; + SearchSourceBuilder searchSourceBuilder; + private long expirationEpochMs; + private int maxResults; + private PriorityQueue topResultsHeap; + private String customResultIndex; + + TopAnomalyResultListener( + ActionListener listener, + SearchSourceBuilder searchSourceBuilder, + long expirationEpochMs, + int maxResults, + OrderType orderType, + String customResultIndex + ) { + this.listener = listener; + this.searchSourceBuilder = searchSourceBuilder; + this.expirationEpochMs = expirationEpochMs; + this.maxResults = maxResults; + this.topResultsHeap = new PriorityQueue<>(maxResults, new Comparator() { + // Sorting by ascending order of anomaly grade or doc count + @Override + public int compare(AnomalyResultBucket bucket1, AnomalyResultBucket bucket2) { + if (orderType == OrderType.SEVERITY) { + return Double.compare(bucket1.getMaxAnomalyGrade(), bucket2.getMaxAnomalyGrade()); + } else { + return Integer.compare(bucket1.getDocCount(), bucket2.getDocCount()); + } + } + }); + this.customResultIndex = customResultIndex; + } + + @Override + public void onResponse(SearchResponse response) { + try { + Aggregations aggs = response.getAggregations(); + if (aggs == null) { + // This would indicate some bug or some opensearch core changes that we are not aware of (we don't keep up-to-date with + // the large amounts of changes there). For example, they may change to if there are results return it; otherwise return + // null instead of an empty Aggregations as they currently do. + logger.warn("Unexpected null aggregation."); + listener.onResponse(new SearchTopAnomalyResultResponse(new ArrayList<>())); + return; + } + + Aggregation aggResults = aggs.get(MULTI_BUCKETS_FIELD); + if (aggResults == null) { + listener.onFailure(new IllegalArgumentException("Failed to find valid aggregation result")); + return; + } + + CompositeAggregation compositeAgg = (CompositeAggregation) aggResults; + List bucketResults = compositeAgg + .getBuckets() + .stream() + .map(bucket -> AnomalyResultBucket.createAnomalyResultBucket(bucket)) + .collect(Collectors.toList()); + + // Add all of the results to the heap, and only keep the top maxResults buckets. + // Note that the top results heap is implemented as a min heap, so by polling + // the lowest values from the heap, only the top values remain. 
+ topResultsHeap.addAll(bucketResults); + while (topResultsHeap.size() > maxResults) { + topResultsHeap.poll(); + } + + // If afterKey is null: we've hit the end of results. Return the results + Map afterKey = compositeAgg.afterKey(); + if (afterKey == null) { + listener.onResponse(new SearchTopAnomalyResultResponse(getDescendingOrderListFromHeap(topResultsHeap))); + } else if (expirationEpochMs < clock.millis()) { + if (topResultsHeap.isEmpty()) { + listener.onFailure(new TimeSeriesException("Timed out getting all top anomaly results. Please retry later.")); + } else { + logger.info("Timed out getting all top anomaly results. Sending back partial results."); + listener.onResponse(new SearchTopAnomalyResultResponse(getDescendingOrderListFromHeap(topResultsHeap))); + } + } else { + CompositeAggregationBuilder aggBuilder = (CompositeAggregationBuilder) searchSourceBuilder + .aggregations() + .getAggregatorFactories() + .iterator() + .next(); + aggBuilder.aggregateAfter(afterKey); + + // Searching more, using an updated source with an after_key + SearchRequest searchRequest = Strings.isNullOrEmpty(customResultIndex) + ? new SearchRequest().indices(defaultIndex) + : new SearchRequest().indices(defaultIndex, customResultIndex); + searchHandler.search(searchRequest.source(searchSourceBuilder), this); + } + + } catch (Exception e) { + onFailure(e); + } + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to paginate top anomaly results", e); + listener.onFailure(e); + } + } + + /** + * Generates the entire search request to pass to the search handler + * + * @param request the request containing the all of the user-specified parameters needed to generate the request + * @return the SearchRequest to pass to the SearchHandler + */ + private SearchRequest generateSearchRequest(SearchTopAnomalyResultRequest request) { + SearchRequest searchRequest = new SearchRequest().indices(defaultIndex); + QueryBuilder query = generateQuery(request); + AggregationBuilder aggregation = generateAggregation(request); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).aggregation(aggregation); + searchRequest.source(searchSourceBuilder); + return searchRequest; + } + + /** + * Generates the query with appropriate filters on the results indices. 
+ * If fetching real-time results: + * 1) term filter on detector_id + * 2) must_not filter on task_id (because real-time results don't have a 'task_id' field associated with them in the document) + * If fetching historical results: + * 1) term filter on the task_id + * + * @param request the request containing the necessary fields to generate the query + * @return the generated query as a QueryBuilder + */ + private QueryBuilder generateQuery(SearchTopAnomalyResultRequest request) { + BoolQueryBuilder query = new BoolQueryBuilder(); + + // Adding the date range and anomaly grade filters (needed regardless of real-time or historical) + RangeQueryBuilder dateRangeFilter = QueryBuilders + .rangeQuery(CommonName.DATA_END_TIME_FIELD) + .gte(request.getStartTime().toEpochMilli()) + .lte(request.getEndTime().toEpochMilli()); + RangeQueryBuilder anomalyGradeFilter = QueryBuilders.rangeQuery(AnomalyResult.ANOMALY_GRADE_FIELD).gt(0); + query.filter(dateRangeFilter).filter(anomalyGradeFilter); + + if (request.getHistorical() == true) { + TermQueryBuilder taskIdFilter = QueryBuilders.termQuery(CommonName.TASK_ID_FIELD, request.getTaskId()); + query.filter(taskIdFilter); + } else { + TermQueryBuilder detectorIdFilter = QueryBuilders.termQuery(AnomalyResult.DETECTOR_ID_FIELD, request.getId()); + ExistsQueryBuilder taskIdExistsFilter = QueryBuilders.existsQuery(CommonName.TASK_ID_FIELD); + query.filter(detectorIdFilter).mustNot(taskIdExistsFilter); + } + return query; + } + + /** + * Generates the composite aggregation. + * Creating a list of sources based on the set of category fields, and sorting on the returned result buckets + * + * @param request the request containing the necessary fields to generate the aggregation + * @return the generated aggregation as an AggregationBuilder + */ + private AggregationBuilder generateAggregation(SearchTopAnomalyResultRequest request) { + List> sources = new ArrayList<>(); + for (String categoryField : request.getCategoryFields()) { + Script script = getScriptForCategoryField(categoryField); + sources.add(new TermsValuesSourceBuilder(categoryField).script(script)); + } + + // Generate the max anomaly grade aggregation + AggregationBuilder maxAnomalyGradeAggregation = AggregationBuilders + .max(AnomalyResultBucket.MAX_ANOMALY_GRADE_FIELD) + .field(AnomalyResult.ANOMALY_GRADE_FIELD); + + // Generate the bucket sort aggregation (depends on order type) + String sortField = request.getOrder().equals(OrderType.SEVERITY.getName()) + ? AnomalyResultBucket.MAX_ANOMALY_GRADE_FIELD + : COUNT_FIELD; + BucketSortPipelineAggregationBuilder bucketSort = PipelineAggregatorBuilders + .bucketSort(BUCKET_SORT_FIELD, new ArrayList<>(Arrays.asList(new FieldSortBuilder(sortField).order(SortOrder.DESC)))); + + return AggregationBuilders + .composite(MULTI_BUCKETS_FIELD, sources) + .size(PAGE_SIZE) + .subAggregation(maxAnomalyGradeAggregation) + .subAggregation(bucketSort); + } + + /** + * Generates the painless script to fetch results that have an entity name matching the passed-in category field. 
+ * + * @param categoryField the category field to be used as a source + * @return the painless script used to get all docs with entity name values matching the category field + */ + private Script getScriptForCategoryField(String categoryField) { + StringBuilder builder = new StringBuilder() + .append("String value = null;") + .append("if (params == null || params._source == null || params._source.entity == null) {") + .append("return \"\"") + .append("}") + .append("for (item in params._source.entity) {") + .append("if (item[\"name\"] == params[\"categoryField\"]) {") + .append("value = item['value'];") + .append("break;") + .append("}") + .append("}") + .append("return value;"); + + // The last argument contains the K/V pair to inject the categoryField value into the script + return new Script( + ScriptType.INLINE, + "painless", + builder.toString(), + Collections.emptyMap(), + ImmutableMap.of("categoryField", categoryField) + ); + } + + /** + * Creates a descending-ordered List from a min heap. + * + * @param minHeap a min heap + * @return an ordered List containing all of the elements in the heap + */ + private List getDescendingOrderListFromHeap(PriorityQueue minHeap) { + List topResultsHeapAsList = new ArrayList<>(); + while (!minHeap.isEmpty()) { + topResultsHeapAsList.add(minHeap.poll()); + } + // Need to reverse the list, since polling from a min heap + // will return results in ascending order + Collections.reverse(topResultsHeapAsList); + return topResultsHeapAsList; + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorAction.java-e b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorAction.java-e new file mode 100644 index 000000000..3c1f53d9d --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorAction.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class StatsAnomalyDetectorAction extends ActionType { + // External Action which used for public facing RestAPIs. 
+ public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/stats"; + public static final StatsAnomalyDetectorAction INSTANCE = new StatsAnomalyDetectorAction(); + + private StatsAnomalyDetectorAction() { + super(NAME, StatsAnomalyDetectorResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java index eeecff7da..4a233fc62 100644 --- a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java @@ -15,8 +15,8 @@ import org.opensearch.action.ActionResponse; import org.opensearch.ad.stats.ADStatsResponse; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java-e b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java-e new file mode 100644 index 000000000..4a233fc62 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java-e @@ -0,0 +1,49 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionResponse; +import org.opensearch.ad.stats.ADStatsResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +public class StatsAnomalyDetectorResponse extends ActionResponse implements ToXContentObject { + private ADStatsResponse adStatsResponse; + + public StatsAnomalyDetectorResponse(StreamInput in) throws IOException { + super(in); + adStatsResponse = new ADStatsResponse(in); + } + + public StatsAnomalyDetectorResponse(ADStatsResponse adStatsResponse) { + this.adStatsResponse = adStatsResponse; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + adStatsResponse.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + adStatsResponse.toXContent(builder, params); + return builder; + } + + protected ADStatsResponse getAdStatsResponse() { + return adStatsResponse; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java index 16fdfd611..caf4bd42a 100644 --- a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java @@ -35,7 +35,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.rest.RestStatus; +import 
org.opensearch.core.rest.RestStatus; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.bucket.terms.StringTerms; import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java-e new file mode 100644 index 000000000..caf4bd42a --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java-e @@ -0,0 +1,213 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_GET_STATS; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorType; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.stats.ADStatsResponse; +import org.opensearch.ad.util.MultiResponsesDelegateActionListener; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.bucket.terms.StringTerms; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.transport.TransportService; + +public class StatsAnomalyDetectorTransportAction extends HandledTransportAction { + public static final String DETECTOR_TYPE_AGG = "detector_type_agg"; + private final Logger logger = LogManager.getLogger(StatsAnomalyDetectorTransportAction.class); + + private final Client client; + private final ADStats adStats; + private final ClusterService clusterService; + + @Inject + public StatsAnomalyDetectorTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ADStats adStats, + ClusterService clusterService + + ) { + super(StatsAnomalyDetectorAction.NAME, transportService, actionFilters, ADStatsRequest::new); + this.client = client; + this.adStats = adStats; + this.clusterService = clusterService; + } + + @Override + protected void doExecute(Task task, ADStatsRequest request, ActionListener actionListener) { + ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_GET_STATS); + try (ThreadContext.StoredContext context = 
client.threadPool().getThreadContext().stashContext()) { + getStats(client, listener, request); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + /** + * Make the 2 requests to get the node and cluster statistics + * + * @param client Client + * @param listener Listener to send response + * @param adStatsRequest Request containing stats to be retrieved + */ + private void getStats(Client client, ActionListener listener, ADStatsRequest adStatsRequest) { + // Use MultiResponsesDelegateActionListener to execute 2 async requests and create the response once they finish + MultiResponsesDelegateActionListener delegateListener = new MultiResponsesDelegateActionListener<>( + getRestStatsListener(listener), + 2, + "Unable to return AD Stats", + false + ); + + getClusterStats(client, delegateListener, adStatsRequest); + getNodeStats(client, delegateListener, adStatsRequest); + } + + /** + * Listener sends response once Node Stats and Cluster Stats are gathered + * + * @param listener Listener to send response + * @return ActionListener for ADStatsResponse + */ + private ActionListener getRestStatsListener(ActionListener listener) { + return ActionListener + .wrap( + adStatsResponse -> { listener.onResponse(new StatsAnomalyDetectorResponse(adStatsResponse)); }, + exception -> listener.onFailure(new OpenSearchStatusException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)) + ); + } + + /** + * Make async request to get the number of detectors in AnomalyDetector.ANOMALY_DETECTORS_INDEX if necessary + * and, onResponse, gather the cluster statistics + * + * @param client Client + * @param listener MultiResponsesDelegateActionListener to be used once both requests complete + * @param adStatsRequest Request containing stats to be retrieved + */ + private void getClusterStats( + Client client, + MultiResponsesDelegateActionListener listener, + ADStatsRequest adStatsRequest + ) { + ADStatsResponse adStatsResponse = new ADStatsResponse(); + if ((adStatsRequest.getStatsToBeRetrieved().contains(StatNames.DETECTOR_COUNT.getName()) + || adStatsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()) + || adStatsRequest.getStatsToBeRetrieved().contains(StatNames.MULTI_ENTITY_DETECTOR_COUNT.getName())) + && clusterService.state().getRoutingTable().hasIndex(CommonName.CONFIG_INDEX)) { + + TermsAggregationBuilder termsAgg = AggregationBuilders.terms(DETECTOR_TYPE_AGG).field(AnomalyDetector.DETECTOR_TYPE_FIELD); + SearchRequest request = new SearchRequest() + .indices(CommonName.CONFIG_INDEX) + .source(new SearchSourceBuilder().aggregation(termsAgg).size(0).trackTotalHits(true)); + + client.search(request, ActionListener.wrap(r -> { + StringTerms aggregation = r.getAggregations().get(DETECTOR_TYPE_AGG); + List buckets = aggregation.getBuckets(); + long totalDetectors = r.getHits().getTotalHits().value; + long totalSingleEntityDetectors = 0; + long totalMultiEntityDetectors = 0; + for (StringTerms.Bucket b : buckets) { + if (AnomalyDetectorType.SINGLE_ENTITY.name().equals(b.getKeyAsString()) + || AnomalyDetectorType.REALTIME_SINGLE_ENTITY.name().equals(b.getKeyAsString()) + || AnomalyDetectorType.HISTORICAL_SINGLE_ENTITY.name().equals(b.getKeyAsString())) { + totalSingleEntityDetectors += b.getDocCount(); + } + if (AnomalyDetectorType.MULTI_ENTITY.name().equals(b.getKeyAsString()) + || AnomalyDetectorType.REALTIME_MULTI_ENTITY.name().equals(b.getKeyAsString()) + || AnomalyDetectorType.HISTORICAL_MULTI_ENTITY.name().equals(b.getKeyAsString())) 
{ + totalMultiEntityDetectors += b.getDocCount(); + } + } + if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.DETECTOR_COUNT.getName())) { + adStats.getStat(StatNames.DETECTOR_COUNT.getName()).setValue(totalDetectors); + } + if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName())) { + adStats.getStat(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()).setValue(totalSingleEntityDetectors); + } + if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.MULTI_ENTITY_DETECTOR_COUNT.getName())) { + adStats.getStat(StatNames.MULTI_ENTITY_DETECTOR_COUNT.getName()).setValue(totalMultiEntityDetectors); + } + adStatsResponse.setClusterStats(getClusterStatsMap(adStatsRequest)); + listener.onResponse(adStatsResponse); + }, e -> listener.onFailure(e))); + } else { + adStatsResponse.setClusterStats(getClusterStatsMap(adStatsRequest)); + listener.onResponse(adStatsResponse); + } + } + + /** + * Collect Cluster Stats into map to be retrieved + * + * @param adStatsRequest Request containing stats to be retrieved + * @return Map containing Cluster Stats + */ + private Map getClusterStatsMap(ADStatsRequest adStatsRequest) { + Map clusterStats = new HashMap<>(); + Set statsToBeRetrieved = adStatsRequest.getStatsToBeRetrieved(); + adStats + .getClusterStats() + .entrySet() + .stream() + .filter(s -> statsToBeRetrieved.contains(s.getKey())) + .forEach(s -> clusterStats.put(s.getKey(), s.getValue().getValue())); + return clusterStats; + } + + /** + * Make async request to get the Anomaly Detection statistics from each node and, onResponse, set the + * ADStatsNodesResponse field of ADStatsResponse + * + * @param client Client + * @param listener MultiResponsesDelegateActionListener to be used once both requests complete + * @param adStatsRequest Request containing stats to be retrieved + */ + private void getNodeStats( + Client client, + MultiResponsesDelegateActionListener listener, + ADStatsRequest adStatsRequest + ) { + client.execute(ADStatsNodesAction.INSTANCE, adStatsRequest, ActionListener.wrap(adStatsResponse -> { + ADStatsResponse restADStatsResponse = new ADStatsResponse(); + restADStatsResponse.setADStatsNodesResponse(adStatsResponse); + listener.onResponse(restADStatsResponse); + }, listener::onFailure)); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorAction.java-e b/src/main/java/org/opensearch/ad/transport/StopDetectorAction.java-e new file mode 100644 index 000000000..5c7182920 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/StopDetectorAction.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class StopDetectorAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. 
+ public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detector/stop"; + public static final StopDetectorAction INSTANCE = new StopDetectorAction(); + + private StopDetectorAction() { + super(NAME, StopDetectorResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java b/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java index 787bb851a..71563a2cd 100644 --- a/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java +++ b/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java @@ -21,11 +21,11 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.common.io.stream.InputStreamStreamInput; -import org.opensearch.common.io.stream.OutputStreamStreamOutput; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java-e b/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java-e new file mode 100644 index 000000000..9a854ded2 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java-e @@ -0,0 +1,94 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +public class StopDetectorRequest extends ActionRequest implements ToXContentObject { + + private String adID; + + public StopDetectorRequest() {} + + public StopDetectorRequest(StreamInput in) throws IOException { + super(in); + this.adID = in.readString(); + } + + public StopDetectorRequest(String adID) { + super(); + this.adID = adID; + } + + public String getAdID() { + return adID; + } + + public StopDetectorRequest adID(String adID) { + this.adID = adID; + return this; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(adID); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(adID)) { + validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + } + return validationException; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.endObject(); + return builder; + } + + public static StopDetectorRequest fromActionRequest(final ActionRequest actionRequest) { + if (actionRequest instanceof StopDetectorRequest) { + return (StopDetectorRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new StopDetectorRequest(input); + } + } catch (IOException e) { + throw new IllegalArgumentException("failed to parse ActionRequest into StopDetectorRequest", e); + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorResponse.java b/src/main/java/org/opensearch/ad/transport/StopDetectorResponse.java index 3d7598176..b3606b918 100644 --- a/src/main/java/org/opensearch/ad/transport/StopDetectorResponse.java +++ b/src/main/java/org/opensearch/ad/transport/StopDetectorResponse.java @@ -16,10 +16,10 @@ import java.io.IOException; import org.opensearch.action.ActionResponse; -import org.opensearch.common.io.stream.InputStreamStreamInput; -import org.opensearch.common.io.stream.OutputStreamStreamOutput; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import 
org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorResponse.java-e b/src/main/java/org/opensearch/ad/transport/StopDetectorResponse.java-e new file mode 100644 index 000000000..b3606b918 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/StopDetectorResponse.java-e @@ -0,0 +1,70 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +import org.opensearch.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +public class StopDetectorResponse extends ActionResponse implements ToXContentObject { + public static final String SUCCESS_JSON_KEY = "success"; + private boolean success; + + public StopDetectorResponse(boolean success) { + this.success = success; + } + + public StopDetectorResponse(StreamInput in) throws IOException { + super(in); + success = in.readBoolean(); + } + + public boolean success() { + return success; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(success); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(SUCCESS_JSON_KEY, success); + builder.endObject(); + return builder; + } + + public static StopDetectorResponse fromActionResponse(final ActionResponse actionResponse) { + if (actionResponse instanceof StopDetectorResponse) { + return (StopDetectorResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (InputStreamStreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new StopDetectorResponse(input); + } + } catch (IOException e) { + throw new IllegalArgumentException("failed to parse ActionResponse into StopDetectorResponse", e); + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/StopDetectorTransportAction.java-e new file mode 100644 index 000000000..f84d4114e --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/StopDetectorTransportAction.java-e @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_STOP_DETECTOR; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.inject.Inject; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.transport.TransportService; + +public class StopDetectorTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(StopDetectorTransportAction.class); + + private final Client client; + private final DiscoveryNodeFilterer nodeFilter; + + @Inject + public StopDetectorTransportAction( + TransportService transportService, + DiscoveryNodeFilterer nodeFilter, + ActionFilters actionFilters, + Client client + ) { + super(StopDetectorAction.NAME, transportService, actionFilters, StopDetectorRequest::new); + this.client = client; + this.nodeFilter = nodeFilter; + } + + @Override + protected void doExecute(Task task, ActionRequest actionRequest, ActionListener listener) { + StopDetectorRequest request = StopDetectorRequest.fromActionRequest(actionRequest); + String adID = request.getAdID(); + try { + DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); + DeleteModelRequest modelDeleteRequest = new DeleteModelRequest(adID, dataNodes); + client.execute(DeleteModelAction.INSTANCE, modelDeleteRequest, ActionListener.wrap(response -> { + if (response.hasFailures()) { + LOG.warn("Cannot delete all models of detector {}", adID); + for (FailedNodeException failedNodeException : response.failures()) { + LOG.warn("Deleting models of node has exception", failedNodeException); + } + // if customers are using an updated detector and we haven't deleted old + // checkpoints, customer would have trouble + listener.onResponse(new StopDetectorResponse(false)); + } else { + LOG.info("models of detector {} get deleted", adID); + listener.onResponse(new StopDetectorResponse(true)); + } + }, exception -> { + LOG.error(new ParameterizedMessage("Deletion of detector [{}] has exception.", adID), exception); + listener.onResponse(new StopDetectorResponse(false)); + })); + } catch (Exception e) { + LOG.error(FAIL_TO_STOP_DETECTOR + " " + adID, e); + Throwable cause = ExceptionsHelper.unwrapCause(e); + listener.onFailure(new InternalFailure(adID, FAIL_TO_STOP_DETECTOR, cause)); + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ThresholdResultAction.java-e b/src/main/java/org/opensearch/ad/transport/ThresholdResultAction.java-e new file mode 100644 index 000000000..1561c08dc --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ThresholdResultAction.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. 
See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class ThresholdResultAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "threshold/result"; + public static final ThresholdResultAction INSTANCE = new ThresholdResultAction(); + + private ThresholdResultAction() { + super(NAME, ThresholdResultResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ThresholdResultRequest.java b/src/main/java/org/opensearch/ad/transport/ThresholdResultRequest.java index 72751bf9a..d9310ae31 100644 --- a/src/main/java/org/opensearch/ad/transport/ThresholdResultRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ThresholdResultRequest.java @@ -19,9 +19,9 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.constant.CommonName; diff --git a/src/main/java/org/opensearch/ad/transport/ThresholdResultRequest.java-e b/src/main/java/org/opensearch/ad/transport/ThresholdResultRequest.java-e new file mode 100644 index 000000000..083f45c0a --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ThresholdResultRequest.java-e @@ -0,0 +1,94 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonName; + +public class ThresholdResultRequest extends ActionRequest implements ToXContentObject { + private String adID; + private String modelID; + private double rcfScore; + + public ThresholdResultRequest(StreamInput in) throws IOException { + super(in); + adID = in.readString(); + modelID = in.readString(); + rcfScore = in.readDouble(); + } + + public ThresholdResultRequest(String adID, String modelID, double rcfScore) { + super(); + this.adID = adID; + this.modelID = modelID; + this.rcfScore = rcfScore; + } + + public double getRCFScore() { + return rcfScore; + } + + public String getAdID() { + return adID; + } + + public String getModelID() { + return modelID; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(adID); + out.writeString(modelID); + out.writeDouble(rcfScore); + } + + /** + * Verify request parameter corresponds to our understanding of the data. + * We don't verify whether rcfScore is less than 0 or not because this cannot happen. + */ + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(adID)) { + validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + } + if (Strings.isEmpty(modelID)) { + validationException = addValidationError(ADCommonMessages.MODEL_ID_MISSING_MSG, validationException); + } + + return validationException; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.field(CommonName.MODEL_ID_FIELD, modelID); + builder.field(ADCommonName.RCF_SCORE_JSON_KEY, rcfScore); + builder.endObject(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ThresholdResultResponse.java b/src/main/java/org/opensearch/ad/transport/ThresholdResultResponse.java index a6cbe06e3..4e117bc3a 100644 --- a/src/main/java/org/opensearch/ad/transport/ThresholdResultResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ThresholdResultResponse.java @@ -15,8 +15,8 @@ import org.opensearch.action.ActionResponse; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/ThresholdResultResponse.java-e b/src/main/java/org/opensearch/ad/transport/ThresholdResultResponse.java-e new file mode 100644 index 000000000..4e117bc3a --- /dev/null +++ 
b/src/main/java/org/opensearch/ad/transport/ThresholdResultResponse.java-e @@ -0,0 +1,61 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +public class ThresholdResultResponse extends ActionResponse implements ToXContentObject { + private double anomalyGrade; + private double confidence; + + public ThresholdResultResponse(double anomalyGrade, double confidence) { + this.anomalyGrade = anomalyGrade; + this.confidence = confidence; + } + + public ThresholdResultResponse(StreamInput in) throws IOException { + super(in); + anomalyGrade = in.readDouble(); + confidence = in.readDouble(); + } + + public double getAnomalyGrade() { + return anomalyGrade; + } + + public double getConfidence() { + return confidence; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(anomalyGrade); + out.writeDouble(confidence); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ADCommonName.ANOMALY_GRADE_JSON_KEY, anomalyGrade); + builder.field(ADCommonName.CONFIDENCE_JSON_KEY, confidence); + builder.endObject(); + return builder; + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ThresholdResultTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/ThresholdResultTransportAction.java-e new file mode 100644 index 000000000..9e292b676 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ThresholdResultTransportAction.java-e @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.common.inject.Inject; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class ThresholdResultTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(ThresholdResultTransportAction.class); + private ModelManager manager; + + @Inject + public ThresholdResultTransportAction(ActionFilters actionFilters, TransportService transportService, ModelManager manager) { + super(ThresholdResultAction.NAME, transportService, actionFilters, ThresholdResultRequest::new); + this.manager = manager; + } + + @Override + protected void doExecute(Task task, ThresholdResultRequest request, ActionListener listener) { + + try { + LOG.info("Serve threshold request for {}", request.getModelID()); + manager + .getThresholdingResult( + request.getAdID(), + request.getModelID(), + request.getRCFScore(), + ActionListener + .wrap( + result -> listener.onResponse(new ThresholdResultResponse(result.getGrade(), result.getConfidence())), + exception -> listener.onFailure(exception) + ) + ); + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorAction.java-e b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorAction.java-e new file mode 100644 index 000000000..432166ac2 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorAction.java-e @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; + +public class ValidateAnomalyDetectorAction extends ActionType { + + public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/validate"; + public static final ValidateAnomalyDetectorAction INSTANCE = new ValidateAnomalyDetectorAction(); + + private ValidateAnomalyDetectorAction() { + super(NAME, ValidateAnomalyDetectorResponse::new); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java index 760ed3539..3ee1f0a6e 100644 --- a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java @@ -16,9 +16,9 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; public class ValidateAnomalyDetectorRequest extends ActionRequest { diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java-e b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java-e new file mode 100644 index 000000000..7abae426e --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java-e @@ -0,0 +1,97 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.common.unit.TimeValue; + +public class ValidateAnomalyDetectorRequest extends ActionRequest { + + private final AnomalyDetector detector; + private final String validationType; + private final Integer maxSingleEntityAnomalyDetectors; + private final Integer maxMultiEntityAnomalyDetectors; + private final Integer maxAnomalyFeatures; + private final TimeValue requestTimeout; + + public ValidateAnomalyDetectorRequest(StreamInput in) throws IOException { + super(in); + detector = new AnomalyDetector(in); + validationType = in.readString(); + maxSingleEntityAnomalyDetectors = in.readInt(); + maxMultiEntityAnomalyDetectors = in.readInt(); + maxAnomalyFeatures = in.readInt(); + requestTimeout = in.readTimeValue(); + } + + public ValidateAnomalyDetectorRequest( + AnomalyDetector detector, + String validationType, + Integer maxSingleEntityAnomalyDetectors, + Integer maxMultiEntityAnomalyDetectors, + Integer maxAnomalyFeatures, + TimeValue requestTimeout + ) { + this.detector = detector; + this.validationType = validationType; + this.maxSingleEntityAnomalyDetectors = maxSingleEntityAnomalyDetectors; + this.maxMultiEntityAnomalyDetectors = maxMultiEntityAnomalyDetectors; + this.maxAnomalyFeatures = maxAnomalyFeatures; + this.requestTimeout = requestTimeout; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + detector.writeTo(out); + out.writeString(validationType); + out.writeInt(maxSingleEntityAnomalyDetectors); + out.writeInt(maxMultiEntityAnomalyDetectors); + out.writeInt(maxAnomalyFeatures); + out.writeTimeValue(requestTimeout); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + public AnomalyDetector getDetector() { + return detector; + } + + public String getValidationType() { + return validationType; + } + + public Integer getMaxSingleEntityAnomalyDetectors() { + return maxSingleEntityAnomalyDetectors; + } + + public Integer getMaxMultiEntityAnomalyDetectors() { + return maxMultiEntityAnomalyDetectors; + } + + public Integer getMaxAnomalyFeatures() { + return maxAnomalyFeatures; + } + + public TimeValue getRequestTimeout() { + return requestTimeout; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponse.java b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponse.java index 621bc63e3..407b4c871 100644 --- a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponse.java @@ -15,8 +15,8 @@ import org.opensearch.action.ActionResponse; import org.opensearch.ad.model.DetectorValidationIssue; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponse.java-e 
b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponse.java-e new file mode 100644 index 000000000..407b4c871 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponse.java-e @@ -0,0 +1,64 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionResponse; +import org.opensearch.ad.model.DetectorValidationIssue; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +public class ValidateAnomalyDetectorResponse extends ActionResponse implements ToXContentObject { + private DetectorValidationIssue issue; + + public DetectorValidationIssue getIssue() { + return issue; + } + + public ValidateAnomalyDetectorResponse(DetectorValidationIssue issue) { + this.issue = issue; + } + + public ValidateAnomalyDetectorResponse(StreamInput in) throws IOException { + super(in); + if (in.readBoolean()) { + issue = new DetectorValidationIssue(in); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (issue != null) { + out.writeBoolean(true); + issue.writeTo(out); + } else { + out.writeBoolean(false); + } + } + + public XContentBuilder toXContent(XContentBuilder builder) throws IOException { + return toXContent(builder, ToXContent.EMPTY_PARAMS); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + if (issue != null) { + xContentBuilder.field(issue.getAspect().getName(), issue); + } + return xContentBuilder.endObject(); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java-e b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java-e new file mode 100644 index 000000000..ecd0ca07c --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java-e @@ -0,0 +1,255 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.timeseries.util.ParseUtils.checkFilterByBackendRoles; +import static org.opensearch.timeseries.util.ParseUtils.getUserContext; + +import java.time.Clock; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.DetectorValidationIssue; +import org.opensearch.ad.rest.handler.ValidateAnomalyDetectorActionHandler; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.transport.TransportService; + +public class ValidateAnomalyDetectorTransportAction extends + HandledTransportAction { + private static final Logger logger = LogManager.getLogger(ValidateAnomalyDetectorTransportAction.class); + + private final Client client; + private final SecurityClientUtil clientUtil; + private final ClusterService clusterService; + private final NamedXContentRegistry xContentRegistry; + private final ADIndexManagement anomalyDetectionIndices; + private final SearchFeatureDao searchFeatureDao; + private volatile Boolean filterByEnabled; + private Clock clock; + private Settings settings; + + @Inject + public ValidateAnomalyDetectorTransportAction( + Client client, + SecurityClientUtil clientUtil, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry, + Settings settings, + ADIndexManagement anomalyDetectionIndices, + ActionFilters actionFilters, + TransportService transportService, + SearchFeatureDao searchFeatureDao + ) { + super(ValidateAnomalyDetectorAction.NAME, transportService, actionFilters, ValidateAnomalyDetectorRequest::new); + this.client = client; + this.clientUtil = clientUtil; + this.clusterService = clusterService; + this.xContentRegistry = xContentRegistry; + this.anomalyDetectionIndices = anomalyDetectionIndices; + this.filterByEnabled = AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings); + 
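// Register a consumer so filterByEnabled tracks dynamic updates to FILTER_BY_BACKEND_ROLES without a node restart. +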
clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + this.searchFeatureDao = searchFeatureDao; + this.clock = Clock.systemUTC(); + this.settings = settings; + } + + @Override + protected void doExecute(Task task, ValidateAnomalyDetectorRequest request, ActionListener listener) { + User user = getUserContext(client); + AnomalyDetector anomalyDetector = request.getDetector(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute(user, listener, () -> validateExecute(request, user, context, listener)); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + private void resolveUserAndExecute( + User requestedUser, + ActionListener listener, + ExecutorFunction function + ) { + try { + // Check if user has backend roles + // When filter by is enabled, block users validating detectors who do not have backend roles. + if (filterByEnabled && !checkFilterByBackendRoles(requestedUser, listener)) { + return; + } + // Validate Detector + function.execute(); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private void validateExecute( + ValidateAnomalyDetectorRequest request, + User user, + ThreadContext.StoredContext storedContext, + ActionListener listener + ) { + storedContext.restore(); + AnomalyDetector detector = request.getDetector(); + ActionListener validateListener = ActionListener.wrap(response -> { + logger.debug("Result of validation process " + response); + // forcing response to be empty + listener.onResponse(new ValidateAnomalyDetectorResponse((DetectorValidationIssue) null)); + }, exception -> { + if (exception instanceof ValidationException) { + // ADValidationException is converted as validation issues returned as response to user + DetectorValidationIssue issue = parseADValidationException((ValidationException) exception); + listener.onResponse(new ValidateAnomalyDetectorResponse(issue)); + return; + } + logger.error(exception); + listener.onFailure(exception); + }); + checkIndicesAndExecute(detector.getIndices(), () -> { + ValidateAnomalyDetectorActionHandler handler = new ValidateAnomalyDetectorActionHandler( + clusterService, + client, + clientUtil, + validateListener, + anomalyDetectionIndices, + detector, + request.getRequestTimeout(), + request.getMaxSingleEntityAnomalyDetectors(), + request.getMaxMultiEntityAnomalyDetectors(), + request.getMaxAnomalyFeatures(), + RestRequest.Method.POST, + xContentRegistry, + user, + searchFeatureDao, + request.getValidationType(), + clock, + settings + ); + try { + handler.start(); + } catch (Exception exception) { + String errorMessage = String + .format(Locale.ROOT, "Unknown exception caught while validating detector %s", request.getDetector()); + logger.error(errorMessage, exception); + listener.onFailure(exception); + } + }, listener); + } + + protected DetectorValidationIssue parseADValidationException(ValidationException exception) { + String originalErrorMessage = exception.getMessage(); + String errorMessage = ""; + Map subIssues = null; + IntervalTimeConfiguration intervalSuggestion = exception.getIntervalSuggestion(); + switch (exception.getType()) { + case FEATURE_ATTRIBUTES: + int firstLeftBracketIndex = originalErrorMessage.indexOf("["); + int lastRightBracketIndex = originalErrorMessage.lastIndexOf("]"); + if (firstLeftBracketIndex != -1) { + // if feature issue messages are between square brackets like + // [Feature has issue: A, Feature has issue: B] 
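+ // Strip the enclosing brackets, then split the per-feature messages into sub-issues keyed by feature name (see getFeatureSubIssuesFromErrorMessage below).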
+ errorMessage = originalErrorMessage.substring(firstLeftBracketIndex + 1, lastRightBracketIndex); + subIssues = getFeatureSubIssuesFromErrorMessage(errorMessage); + } else { + // features having issue like over max feature limit, duplicate feature name, etc. + errorMessage = originalErrorMessage; + } + break; + case NAME: + case CATEGORY: + case DETECTION_INTERVAL: + case FILTER_QUERY: + case TIMEFIELD_FIELD: + case SHINGLE_SIZE_FIELD: + case WINDOW_DELAY: + case RESULT_INDEX: + case GENERAL_SETTINGS: + case AGGREGATION: + case TIMEOUT: + case INDICES: + errorMessage = originalErrorMessage; + break; + } + return new DetectorValidationIssue(exception.getAspect(), exception.getType(), errorMessage, subIssues, intervalSuggestion); + } + + // Example of method output: + // String input:Feature has invalid query returning empty aggregated data: average_total_rev, Feature has invalid query causing runtime + // exception: average_total_rev-2 + // output: "sub_issues": { + // "average_total_rev": "Feature has invalid query returning empty aggregated data", + // "average_total_rev-2": "Feature has invalid query causing runtime exception" + // } + private Map getFeatureSubIssuesFromErrorMessage(String errorMessage) { + Map result = new HashMap<>(); + String[] subIssueMessagesSuffix = errorMessage.split(", "); + for (int i = 0; i < subIssueMessagesSuffix.length; i++) { + result.put(subIssueMessagesSuffix[i].split(": ")[1], subIssueMessagesSuffix[i].split(": ")[0]); + } + return result; + } + + private void checkIndicesAndExecute( + List indices, + ExecutorFunction function, + ActionListener listener + ) { + SearchRequest searchRequest = new SearchRequest() + .indices(indices.toArray(new String[0])) + .source(new SearchSourceBuilder().size(1).query(QueryBuilders.matchAllQuery())); + client.search(searchRequest, ActionListener.wrap(r -> function.execute(), e -> { + if (e instanceof IndexNotFoundException) { + // IndexNotFoundException is converted to a ADValidationException that gets + // parsed to a DetectorValidationIssue that is returned to + // the user as a response indicating index doesn't exist + DetectorValidationIssue issue = parseADValidationException( + new ValidationException(ADCommonMessages.INDEX_NOT_FOUND, ValidationIssueType.INDICES, ValidationAspect.DETECTOR) + ); + listener.onResponse(new ValidateAnomalyDetectorResponse(issue)); + return; + } + logger.error(e); + listener.onFailure(e); + })); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/handler/ADSearchHandler.java-e b/src/main/java/org/opensearch/ad/transport/handler/ADSearchHandler.java-e new file mode 100644 index 000000000..4831eae88 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/handler/ADSearchHandler.java-e @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport.handler; + +import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_SEARCH; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.timeseries.util.ParseUtils.addUserBackendRolesFilter; +import static org.opensearch.timeseries.util.ParseUtils.getUserContext; +import static org.opensearch.timeseries.util.ParseUtils.isAdmin; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; + +/** + * Handle general search request, check user role and return search response. + */ +public class ADSearchHandler { + private final Logger logger = LogManager.getLogger(ADSearchHandler.class); + private final Client client; + private volatile Boolean filterEnabled; + + public ADSearchHandler(Settings settings, ClusterService clusterService, Client client) { + this.client = client; + filterEnabled = AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterEnabled = it); + } + + /** + * Validate user role, add backend role filter if filter enabled + * and execute search. + * + * @param request search request + * @param actionListener action listerner + */ + public void search(SearchRequest request, ActionListener actionListener) { + User user = getUserContext(client); + ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_SEARCH); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + validateRole(request, user, listener); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + private void validateRole(SearchRequest request, User user, ActionListener listener) { + if (user == null || !filterEnabled || isAdmin(user)) { + // Case 1: user == null when 1. Security is disabled. 2. When user is super-admin + // Case 2: If Security is enabled and filter is disabled, proceed with search as + // user is already authenticated to hit this API. 
+ // case 3: user is admin which means we don't have to check backend role filtering + client.search(request, listener); + } else { + // Security is enabled, filter is enabled and user isn't admin + try { + addUserBackendRolesFilter(user, request.source()); + logger.debug("Filtering result by " + user.getBackendRoles()); + client.search(request, listener); + } catch (Exception e) { + listener.onFailure(e); + } + } + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/handler/AnomalyIndexHandler.java-e b/src/main/java/org/opensearch/ad/transport/handler/AnomalyIndexHandler.java-e new file mode 100644 index 000000000..371640ad2 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/handler/AnomalyIndexHandler.java-e @@ -0,0 +1,233 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport.handler; + +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.opensearch.timeseries.constant.CommonMessages.CAN_NOT_FIND_RESULT_INDEX; + +import java.util.Iterator; +import java.util.Locale; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.bulk.BackoffPolicy; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.util.BulkUtil; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.IndexUtils; +import org.opensearch.client.Client; +import org.opensearch.cluster.block.ClusterBlockLevel; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.util.RestHandlerUtils; + +public class AnomalyIndexHandler { + private static final Logger LOG = LogManager.getLogger(AnomalyIndexHandler.class); + static final String FAIL_TO_SAVE_ERR_MSG = "Fail to save %s: "; + static final String SUCCESS_SAVING_MSG = "Succeed in saving %s"; + static final String CANNOT_SAVE_ERR_MSG = "Cannot save %s due to write block."; + static final String RETRY_SAVING_ERR_MSG = "Retry in saving %s: "; + + protected final Client client; + + protected final ThreadPool threadPool; + protected final BackoffPolicy savingBackoffPolicy; + protected final String indexName; + protected final ADIndexManagement anomalyDetectionIndices; + // whether save to a specific doc id or not. False by default. 
+ protected boolean fixedDoc; + protected final ClientUtil clientUtil; + protected final IndexUtils indexUtils; + protected final ClusterService clusterService; + + /** + * Abstract class for index operation. + * + * @param client client to OpenSearch query + * @param settings accessor for node settings. + * @param threadPool used to invoke specific threadpool to execute + * @param indexName name of index to save to + * @param anomalyDetectionIndices anomaly detection indices + * @param clientUtil client wrapper + * @param indexUtils Index util classes + * @param clusterService accessor to ES cluster service + */ + public AnomalyIndexHandler( + Client client, + Settings settings, + ThreadPool threadPool, + String indexName, + ADIndexManagement anomalyDetectionIndices, + ClientUtil clientUtil, + IndexUtils indexUtils, + ClusterService clusterService + ) { + this.client = client; + this.threadPool = threadPool; + this.savingBackoffPolicy = BackoffPolicy + .exponentialBackoff( + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY.get(settings), + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF.get(settings) + ); + this.indexName = indexName; + this.anomalyDetectionIndices = anomalyDetectionIndices; + this.fixedDoc = false; + this.clientUtil = clientUtil; + this.indexUtils = indexUtils; + this.clusterService = clusterService; + } + + /** + * Since the constructor needs to provide injected value and Guice does not allow Boolean to be there + * (claiming it does not know how to instantiate it), caller needs to manually set it to true if + * it want to save to a specific doc. + * @param fixedDoc whether to save to a specific doc Id + */ + public void setFixedDoc(boolean fixedDoc) { + this.fixedDoc = fixedDoc; + } + + // TODO: check if user has permission to index. + public void index(T toSave, String detectorId, String customIndexName) { + if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, this.indexName)) { + LOG.warn(String.format(Locale.ROOT, CANNOT_SAVE_ERR_MSG, detectorId)); + return; + } + + try { + if (customIndexName != null) { + // Only create custom AD result index when create detector, won’t recreate custom AD result index in realtime + // job and historical analysis later if it’s deleted. If user delete the custom AD result index, and AD plugin + // recreate it, that may bring confusion. 
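+ // A missing custom result index, or one whose mapping is not a valid AD result mapping, fails the run with EndRunException instead of being recreated here.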
+ if (!anomalyDetectionIndices.doesIndexExist(customIndexName)) { + throw new EndRunException(detectorId, CAN_NOT_FIND_RESULT_INDEX + customIndexName, true); + } + if (!anomalyDetectionIndices.isValidResultIndexMapping(customIndexName)) { + throw new EndRunException(detectorId, "wrong index mapping of custom AD result index", true); + } + save(toSave, detectorId, customIndexName); + return; + } + if (!anomalyDetectionIndices.doesDefaultResultIndexExist()) { + anomalyDetectionIndices + .initDefaultResultIndexDirectly( + ActionListener.wrap(initResponse -> onCreateIndexResponse(initResponse, toSave, detectorId), exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + save(toSave, detectorId); + } else { + throw new TimeSeriesException( + detectorId, + String.format(Locale.ROOT, "Unexpected error creating index %s", indexName), + exception + ); + } + }) + ); + } else { + save(toSave, detectorId); + } + } catch (Exception e) { + throw new TimeSeriesException( + detectorId, + String.format(Locale.ROOT, "Error in saving %s for detector %s", indexName, detectorId), + e + ); + } + } + + private void onCreateIndexResponse(CreateIndexResponse response, T toSave, String detectorId) { + if (response.isAcknowledged()) { + save(toSave, detectorId); + } else { + throw new TimeSeriesException( + detectorId, + String.format(Locale.ROOT, "Creating %s with mappings call not acknowledged.", indexName) + ); + } + } + + protected void save(T toSave, String detectorId) { + save(toSave, detectorId, indexName); + } + + // TODO: Upgrade custom result index mapping to latest version? + // It may bring some issue if we upgrade the custom result index mapping while user is using that index + // for other use cases. One easy solution is to tell user only use custom result index for AD plugin. + // For the first release of custom result index, it's not a issue. Will leave this to next phase. + protected void save(T toSave, String detectorId, String indexName) { + try (XContentBuilder builder = jsonBuilder()) { + IndexRequest indexRequest = new IndexRequest(indexName).source(toSave.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)); + if (fixedDoc) { + indexRequest.id(detectorId); + } + + saveIteration(indexRequest, detectorId, savingBackoffPolicy.iterator()); + } catch (Exception e) { + LOG.error(String.format(Locale.ROOT, "Failed to save %s", indexName), e); + throw new TimeSeriesException(detectorId, String.format(Locale.ROOT, "Cannot save %s", indexName)); + } + } + + void saveIteration(IndexRequest indexRequest, String detectorId, Iterator backoff) { + clientUtil + .asyncRequest( + indexRequest, + client::index, + ActionListener + .wrap( + response -> { LOG.debug(String.format(Locale.ROOT, SUCCESS_SAVING_MSG, detectorId)); }, + exception -> { + // OpenSearch has a thread pool and a queue for write per node. A thread + // pool will have N number of workers ready to handle the requests. When a + // request comes and if a worker is free , this is handled by the worker. Now by + // default the number of workers is equal to the number of cores on that CPU. + // When the workers are full and there are more write requests, the request + // will go to queue. The size of queue is also limited. If by default size is, + // say, 200 and if there happens more parallel requests than this, then those + // requests would be rejected as you can see OpenSearchRejectedExecutionException. 
+ // So OpenSearchRejectedExecutionException is the way that OpenSearch tells us that + // it cannot keep up with the current indexing rate. + // When it happens, we should pause indexing a bit before trying again, ideally + // with randomized exponential backoff. + Throwable cause = ExceptionsHelper.unwrapCause(exception); + if (!(cause instanceof OpenSearchRejectedExecutionException) || !backoff.hasNext()) { + LOG.error(String.format(Locale.ROOT, FAIL_TO_SAVE_ERR_MSG, detectorId), cause); + } else { + TimeValue nextDelay = backoff.next(); + LOG.warn(String.format(Locale.ROOT, RETRY_SAVING_ERR_MSG, detectorId), cause); + threadPool + .schedule( + () -> saveIteration(BulkUtil.cloneIndexRequest(indexRequest), detectorId, backoff), + nextDelay, + ThreadPool.Names.SAME + ); + } + } + ) + ); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandler.java-e b/src/main/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandler.java-e new file mode 100644 index 000000000..c021ead73 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandler.java-e @@ -0,0 +1,147 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport.handler; + +import static org.opensearch.ad.constant.ADCommonName.ANOMALY_RESULT_INDEX_ALIAS; +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.opensearch.timeseries.constant.CommonMessages.CAN_NOT_FIND_RESULT_INDEX; + +import java.util.List; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.bulk.BulkRequestBuilder; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.IndexUtils; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.util.RestHandlerUtils; + +public class AnomalyResultBulkIndexHandler extends AnomalyIndexHandler { + private static final Logger LOG = LogManager.getLogger(AnomalyResultBulkIndexHandler.class); + + private ADIndexManagement anomalyDetectionIndices; + + public AnomalyResultBulkIndexHandler( + Client client, + Settings settings, + ThreadPool threadPool, + ClientUtil clientUtil, + IndexUtils indexUtils, + ClusterService clusterService, + ADIndexManagement anomalyDetectionIndices + ) { + super(client, settings, threadPool, ANOMALY_RESULT_INDEX_ALIAS, anomalyDetectionIndices, clientUtil, indexUtils, clusterService); + this.anomalyDetectionIndices = anomalyDetectionIndices; + } + + /** + * Bulk index anomaly results. 
Create anomaly result index first if it doesn't exist. + * + * @param resultIndex anomaly result index + * @param anomalyResults anomaly results + * @param listener action listener + */ + public void bulkIndexAnomalyResult(String resultIndex, List anomalyResults, ActionListener listener) { + if (anomalyResults == null || anomalyResults.size() == 0) { + listener.onResponse(null); + return; + } + String detectorId = anomalyResults.get(0).getConfigId(); + try { + if (resultIndex != null) { + // Only create custom AD result index when create detector, won’t recreate custom AD result index in realtime + // job and historical analysis later if it’s deleted. If user delete the custom AD result index, and AD plugin + // recreate it, that may bring confusion. + if (!anomalyDetectionIndices.doesIndexExist(resultIndex)) { + throw new EndRunException(detectorId, CAN_NOT_FIND_RESULT_INDEX + resultIndex, true); + } + if (!anomalyDetectionIndices.isValidResultIndexMapping(resultIndex)) { + throw new EndRunException(detectorId, "wrong index mapping of custom AD result index", true); + } + bulkSaveDetectorResult(resultIndex, anomalyResults, listener); + return; + } + if (!anomalyDetectionIndices.doesDefaultResultIndexExist()) { + anomalyDetectionIndices.initDefaultResultIndexDirectly(ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + bulkSaveDetectorResult(anomalyResults, listener); + } else { + String error = "Creating anomaly result index with mappings call not acknowledged"; + LOG.error(error); + listener.onFailure(new TimeSeriesException(error)); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + bulkSaveDetectorResult(anomalyResults, listener); + } else { + listener.onFailure(exception); + } + })); + } else { + bulkSaveDetectorResult(anomalyResults, listener); + } + } catch (TimeSeriesException e) { + listener.onFailure(e); + } catch (Exception e) { + String error = "Failed to bulk index anomaly result"; + LOG.error(error, e); + listener.onFailure(new TimeSeriesException(error, e)); + } + } + + private void bulkSaveDetectorResult(List anomalyResults, ActionListener listener) { + bulkSaveDetectorResult(ANOMALY_RESULT_INDEX_ALIAS, anomalyResults, listener); + } + + private void bulkSaveDetectorResult(String resultIndex, List anomalyResults, ActionListener listener) { + BulkRequestBuilder bulkRequestBuilder = client.prepareBulk(); + anomalyResults.forEach(anomalyResult -> { + try (XContentBuilder builder = jsonBuilder()) { + IndexRequest indexRequest = new IndexRequest(resultIndex) + .source(anomalyResult.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)); + bulkRequestBuilder.add(indexRequest); + } catch (Exception e) { + String error = "Failed to prepare request to bulk index anomaly results"; + LOG.error(error, e); + throw new TimeSeriesException(error); + } + }); + client.bulk(bulkRequestBuilder.request(), ActionListener.wrap(r -> { + if (r.hasFailures()) { + String failureMessage = r.buildFailureMessage(); + LOG.warn("Failed to bulk index AD result " + failureMessage); + listener.onFailure(new TimeSeriesException(failureMessage)); + } else { + listener.onResponse(r); + } + + }, e -> { + LOG.error("bulk index ad result failed", e); + listener.onFailure(e); + })); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/handler/MultiEntityResultHandler.java-e 
b/src/main/java/org/opensearch/ad/transport/handler/MultiEntityResultHandler.java-e new file mode 100644 index 000000000..d9d98b74a --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/handler/MultiEntityResultHandler.java-e @@ -0,0 +1,123 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport.handler; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.transport.ADResultBulkAction; +import org.opensearch.ad.transport.ADResultBulkRequest; +import org.opensearch.ad.transport.ADResultBulkResponse; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.IndexUtils; +import org.opensearch.client.Client; +import org.opensearch.cluster.block.ClusterBlockLevel; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.TimeSeriesException; + +/** + * EntityResultTransportAction depends on this class. I cannot use + * AnomalyIndexHandler < AnomalyResult > . All transport actions + * need dependency injection. Guice has a hard time initializing the generic class + * AnomalyIndexHandler < AnomalyResult > due to type erasure. + * To avoid that, I create a class with the concrete type built in so + * that Guice can work out the details.
+ * + */ +public class MultiEntityResultHandler extends AnomalyIndexHandler { + private static final Logger LOG = LogManager.getLogger(MultiEntityResultHandler.class); + // package private for testing + static final String SUCCESS_SAVING_RESULT_MSG = "Result saved successfully."; + static final String CANNOT_SAVE_RESULT_ERR_MSG = "Cannot save results due to write block."; + + @Inject + public MultiEntityResultHandler( + Client client, + Settings settings, + ThreadPool threadPool, + ADIndexManagement anomalyDetectionIndices, + ClientUtil clientUtil, + IndexUtils indexUtils, + ClusterService clusterService + ) { + super( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtils, + clusterService + ); + } + + /** + * Execute the bulk request + * @param currentBulkRequest The bulk request + * @param listener callback after flushing + */ + public void flush(ADResultBulkRequest currentBulkRequest, ActionListener listener) { + if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, this.indexName)) { + listener.onFailure(new TimeSeriesException(CANNOT_SAVE_RESULT_ERR_MSG)); + return; + } + + try { + if (!anomalyDetectionIndices.doesDefaultResultIndexExist()) { + anomalyDetectionIndices.initDefaultResultIndexDirectly(ActionListener.wrap(initResponse -> { + if (initResponse.isAcknowledged()) { + bulk(currentBulkRequest, listener); + } else { + LOG.warn("Creating result index with mappings call not acknowledged."); + listener.onFailure(new TimeSeriesException("", "Creating result index with mappings call not acknowledged.")); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + bulk(currentBulkRequest, listener); + } else { + LOG.warn("Unexpected error creating result index", exception); + listener.onFailure(exception); + } + })); + } else { + bulk(currentBulkRequest, listener); + } + } catch (Exception e) { + LOG.warn("Error in bulking results", e); + listener.onFailure(e); + } + } + + private void bulk(ADResultBulkRequest currentBulkRequest, ActionListener listener) { + if (currentBulkRequest.numberOfActions() <= 0) { + listener.onFailure(new TimeSeriesException("no result to save")); + return; + } + client.execute(ADResultBulkAction.INSTANCE, currentBulkRequest, ActionListener.wrap(response -> { + LOG.debug(SUCCESS_SAVING_RESULT_MSG); + listener.onResponse(response); + }, exception -> { + LOG.error("Error in bulking results", exception); + listener.onFailure(exception); + })); + } +} diff --git a/src/main/java/org/opensearch/ad/util/ADSafeSecurityInjector.java-e b/src/main/java/org/opensearch/ad/util/ADSafeSecurityInjector.java-e new file mode 100644 index 000000000..749a7434c --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/ADSafeSecurityInjector.java-e @@ -0,0 +1,78 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.util; + +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.Strings; +import org.opensearch.timeseries.common.exception.EndRunException; + +public class ADSafeSecurityInjector extends SafeSecurityInjector { + private static final Logger LOG = LogManager.getLogger(ADSafeSecurityInjector.class); + private NodeStateManager nodeStateManager; + + public ADSafeSecurityInjector(String detectorId, Settings settings, ThreadContext tc, NodeStateManager stateManager) { + super(detectorId, settings, tc); + this.nodeStateManager = stateManager; + } + + public void injectUserRolesFromDetector(ActionListener injectListener) { + // if id is null, we cannot fetch a detector + if (Strings.isEmpty(id)) { + LOG.debug("Empty id"); + injectListener.onResponse(null); + return; + } + + // for example, if a user exists in thread context, we don't need to inject user/roles + if (!shouldInject()) { + LOG.debug("Don't need to inject"); + injectListener.onResponse(null); + return; + } + + ActionListener> getDetectorListener = ActionListener.wrap(detectorOp -> { + if (!detectorOp.isPresent()) { + injectListener.onFailure(new EndRunException(id, "AnomalyDetector is not available.", false)); + return; + } + AnomalyDetector detector = detectorOp.get(); + User userInfo = SecurityUtil.getUserFromDetector(detector, settings); + inject(userInfo.getName(), userInfo.getRoles()); + injectListener.onResponse(null); + }, injectListener::onFailure); + + // Since we are gonna read user from detector, make sure the anomaly detector exists and fetched from disk or cached memory + // We don't accept a passed-in AnomalyDetector because the caller might mistakenly not insert any user info in the + // constructed AnomalyDetector and thus poses risks. In the case, if the user is null, we will give admin role. + nodeStateManager.getAnomalyDetector(id, getDetectorListener); + } + + public void injectUserRoles(User user) { + if (user == null) { + LOG.debug("null user"); + return; + } + + if (shouldInject()) { + inject(user.getName(), user.getRoles()); + } + } +} diff --git a/src/main/java/org/opensearch/ad/util/BulkUtil.java-e b/src/main/java/org/opensearch/ad/util/BulkUtil.java-e new file mode 100644 index 000000000..d7fe9c6f6 --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/BulkUtil.java-e @@ -0,0 +1,70 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.util; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexRequest; + +public class BulkUtil { + private static final Logger logger = LogManager.getLogger(BulkUtil.class); + + public static List getFailedIndexRequest(BulkRequest bulkRequest, BulkResponse bulkResponse) { + List res = new ArrayList<>(); + + if (bulkResponse == null || bulkRequest == null) { + return res; + } + + Set failedId = new HashSet<>(); + for (BulkItemResponse response : bulkResponse.getItems()) { + if (response.isFailed() && ExceptionUtil.isRetryAble(response.getFailure().getStatus())) { + failedId.add(response.getId()); + } + } + + for (DocWriteRequest request : bulkRequest.requests()) { + try { + if (failedId.contains(request.id())) { + res.add((IndexRequest) request); + } + } catch (ClassCastException e) { + logger.error("We only support IndexRequest"); + throw e; + } + + } + return res; + } + + /** + * Copy original request's source without other information like autoGeneratedTimestamp. + * otherwise, an exception will be thrown indicating autoGeneratedTimestamp should not be set + * while request id is already set (id is set because we have already sent the request before). + * @param indexRequest request to be cloned + * @return cloned Request + */ + public static IndexRequest cloneIndexRequest(IndexRequest indexRequest) { + IndexRequest newRequest = new IndexRequest(indexRequest.index()); + newRequest.source(indexRequest.source(), indexRequest.getContentType()); + return newRequest; + } +} diff --git a/src/main/java/org/opensearch/ad/util/ClientUtil.java-e b/src/main/java/org/opensearch/ad/util/ClientUtil.java-e new file mode 100644 index 000000000..d85d4fdf7 --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/ClientUtil.java-e @@ -0,0 +1,332 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
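Taken together, the two BulkUtil helpers above support retrying only the retryable part of a failed bulk write: collect the requests whose item failures are retryable, then clone each one so the retry does not carry per-request metadata such as autoGeneratedTimestamp. A minimal sketch of that flow; the wrapper class and method name are hypothetical, not from this patch:

    import java.util.List;

    import org.opensearch.action.bulk.BulkRequest;
    import org.opensearch.action.bulk.BulkResponse;
    import org.opensearch.action.index.IndexRequest;
    import org.opensearch.ad.util.BulkUtil;

    final class BulkRetrySketch {
        // Rebuild a fresh BulkRequest from the retryable failures of a previous attempt.
        static BulkRequest retryableRemainder(BulkRequest sent, BulkResponse outcome) {
            BulkRequest retry = new BulkRequest();
            List<IndexRequest> failed = BulkUtil.getFailedIndexRequest(sent, outcome);
            for (IndexRequest original : failed) {
                retry.add(BulkUtil.cloneIndexRequest(original)); // copies only index name and source
            }
            return retry;
        }
    }
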
+ */ + +package org.opensearch.ad.util; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; + +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; +import java.util.function.Function; + +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchTimeoutException; +import org.opensearch.action.ActionFuture; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionResponse; +import org.opensearch.action.ActionType; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.TaskOperationFailure; +import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksAction; +import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; +import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse; +import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksAction; +import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksRequest; +import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskId; +import org.opensearch.tasks.TaskInfo; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.constant.CommonMessages; + +public class ClientUtil { + private volatile TimeValue requestTimeout; + private Client client; + private final Throttler throttler; + private ThreadPool threadPool; + + @Inject + public ClientUtil(Settings setting, Client client, Throttler throttler, ThreadPool threadPool) { + this.requestTimeout = REQUEST_TIMEOUT.get(setting); + this.client = client; + this.throttler = throttler; + this.threadPool = threadPool; + } + + /** + * Send a nonblocking request with a timeout and return response. Blocking is not allowed in a + * transport call context. See BaseFuture.blockingAllowed + * @param request request like index/search/get + * @param LOG log + * @param consumer functional interface to operate as a client request like client::get + * @param ActionRequest + * @param ActionResponse + * @return the response + * @throws OpenSearchTimeoutException when we cannot get response within time. 
+ * @throws IllegalStateException when the waiting thread is interrupted + */ + public Optional timedRequest( + Request request, + Logger LOG, + BiConsumer> consumer + ) { + try { + AtomicReference respReference = new AtomicReference<>(); + final CountDownLatch latch = new CountDownLatch(1); + + consumer + .accept( + request, + new LatchedActionListener( + ActionListener + .wrap( + response -> { respReference.set(response); }, + exception -> { LOG.error("Cannot get response for request {}, error: {}", request, exception); } + ), + latch + ) + ); + + if (!latch.await(requestTimeout.getSeconds(), TimeUnit.SECONDS)) { + throw new OpenSearchTimeoutException("Cannot get response within time limit: " + request.toString()); + } + return Optional.ofNullable(respReference.get()); + } catch (InterruptedException e1) { + LOG.error(CommonMessages.WAIT_ERR_MSG); + throw new IllegalStateException(e1); + } + } + + /** + * Send an asynchronous request and handle response with the provided listener. + * @param ActionRequest + * @param ActionResponse + * @param request request body + * @param consumer request method, functional interface to operate as a client request like client::get + * @param listener needed to handle response + */ + public void asyncRequest( + Request request, + BiConsumer> consumer, + ActionListener listener + ) { + consumer + .accept( + request, + ActionListener.wrap(response -> { listener.onResponse(response); }, exception -> { listener.onFailure(exception); }) + ); + } + + /** + * Execute a transport action and handle response with the provided listener. + * @param ActionRequest + * @param ActionResponse + * @param action transport action + * @param request request body + * @param listener needed to handle response + */ + public void execute( + ActionType action, + Request request, + ActionListener listener + ) { + client + .execute( + action, + request, + ActionListener.wrap(response -> { listener.onResponse(response); }, exception -> { listener.onFailure(exception); }) + ); + } + + /** + * Send an synchronous request and handle response with the provided listener. + * + * @deprecated use asyncRequest with listener instead. + * + * @param ActionRequest + * @param ActionResponse + * @param request request body + * @param function request method, functional interface to operate as a client request like client::get + * @return the response + */ + @Deprecated + public Response syncRequest( + Request request, + Function> function + ) { + return function.apply(request).actionGet(requestTimeout); + } + + /** + * Send a nonblocking request with a timeout and return response. + * If there is already a query running on given detector, it will try to + * cancel the query. Otherwise it will add this query to the negative cache + * and then attach the AnomalyDetection specific header to the request. + * Once the request complete, it will be removed from the negative cache. + * @param ActionRequest + * @param ActionResponse + * @param request request like index/search/get + * @param LOG log + * @param consumer functional interface to operate as a client request like client::get + * @param detector Anomaly Detector + * @return the response + * @throws InternalFailure when there is already a query running + * @throws OpenSearchTimeoutException when we cannot get response within time. 
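As a usage sketch of timedRequest above: the caller passes the request, a logger, and a client method reference; the helper waits for at most the configured request timeout and returns an Optional that is empty on failure or timeout. The surrounding class and parameter names below are hypothetical:

    import java.util.Optional;

    import org.apache.logging.log4j.Logger;
    import org.opensearch.action.get.GetRequest;
    import org.opensearch.action.get.GetResponse;
    import org.opensearch.ad.util.ClientUtil;
    import org.opensearch.client.Client;

    final class TimedRequestSketch {
        // Fetch a document, waiting up to the configured timeout; an empty Optional means no response arrived.
        static Optional<GetResponse> getWithTimeout(ClientUtil clientUtil, Client client, Logger log, String index, String id) {
            return clientUtil.timedRequest(new GetRequest(index, id), log, client::get);
        }
    }
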
+ * @throws IllegalStateException when the waiting thread is interrupted + */ + public Optional throttledTimedRequest( + Request request, + Logger LOG, + BiConsumer> consumer, + AnomalyDetector detector + ) { + + try { + String detectorId = detector.getId(); + if (!throttler.insertFilteredQuery(detectorId, request)) { + LOG.info("There is one query running for detectorId: {}. Trying to cancel the long running query", detectorId); + cancelRunningQuery(client, detectorId, LOG); + throw new InternalFailure(detector.getId(), "There is already a query running on AnomalyDetector"); + } + AtomicReference respReference = new AtomicReference<>(); + final CountDownLatch latch = new CountDownLatch(1); + + try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) { + assert context != null; + threadPool.getThreadContext().putHeader(Task.X_OPAQUE_ID, ADCommonName.ANOMALY_DETECTOR + ":" + detectorId); + consumer.accept(request, new LatchedActionListener(ActionListener.wrap(response -> { + // clear negative cache + throttler.clearFilteredQuery(detectorId); + respReference.set(response); + }, exception -> { + // clear negative cache + throttler.clearFilteredQuery(detectorId); + LOG.error("Cannot get response for request {}, error: {}", request, exception); + }), latch)); + } catch (Exception e) { + LOG.error("Failed to process the request for detectorId: {}.", detectorId); + throttler.clearFilteredQuery(detectorId); + throw e; + } + + if (!latch.await(requestTimeout.getSeconds(), TimeUnit.SECONDS)) { + throw new OpenSearchTimeoutException("Cannot get response within time limit: " + request.toString()); + } + return Optional.ofNullable(respReference.get()); + } catch (InterruptedException e1) { + LOG.error(CommonMessages.WAIT_ERR_MSG); + throw new IllegalStateException(e1); + } + } + + /** + * Check if there is running query on given detector + * @param detector Anomaly Detector + * @return true if given detector has a running query else false + */ + public boolean hasRunningQuery(AnomalyDetector detector) { + return throttler.getFilteredQuery(detector.getId()).isPresent(); + } + + /** + * Cancel long running query for given detectorId + * @param client OpenSearch client + * @param detectorId Anomaly Detector Id + * @param LOG Logger + */ + private void cancelRunningQuery(Client client, String detectorId, Logger LOG) { + ListTasksRequest listTasksRequest = new ListTasksRequest(); + listTasksRequest.setActions("*search*"); + client + .execute( + ListTasksAction.INSTANCE, + listTasksRequest, + ActionListener.wrap(response -> { onListTaskResponse(response, detectorId, LOG); }, exception -> { + LOG.error("List Tasks failed.", exception); + throw new InternalFailure(detectorId, "Failed to list current tasks", exception); + }) + ); + } + + /** + * Helper function to handle ListTasksResponse + * @param listTasksResponse ListTasksResponse + * @param detectorId Anomaly Detector Id + * @param LOG Logger + */ + private void onListTaskResponse(ListTasksResponse listTasksResponse, String detectorId, Logger LOG) { + List tasks = listTasksResponse.getTasks(); + TaskId matchedParentTaskId = null; + TaskId matchedSingleTaskId = null; + for (TaskInfo task : tasks) { + if (!task.getHeaders().isEmpty() + && task.getHeaders().get(Task.X_OPAQUE_ID).equals(ADCommonName.ANOMALY_DETECTOR + ":" + detectorId)) { + if (!task.getParentTaskId().equals(TaskId.EMPTY_TASK_ID)) { + // we found the parent task, don't need to check more + matchedParentTaskId = task.getParentTaskId(); + break; + } else { + // we 
found one task, keep checking other tasks + matchedSingleTaskId = task.getTaskId(); + } + } + } + // case 1: given detectorId is not in current task list + if (matchedParentTaskId == null && matchedSingleTaskId == null) { + // log and then clear negative cache + LOG.info("Couldn't find task for detectorId: {}. Clean this entry from Throttler", detectorId); + throttler.clearFilteredQuery(detectorId); + return; + } + // case 2: we can find the task for given detectorId + CancelTasksRequest cancelTaskRequest = new CancelTasksRequest(); + if (matchedParentTaskId != null) { + cancelTaskRequest.setParentTaskId(matchedParentTaskId); + LOG.info("Start to cancel task for parentTaskId: {}", matchedParentTaskId.toString()); + } else { + cancelTaskRequest.setTaskId(matchedSingleTaskId); + LOG.info("Start to cancel task for taskId: {}", matchedSingleTaskId.toString()); + } + + client + .execute( + CancelTasksAction.INSTANCE, + cancelTaskRequest, + ActionListener.wrap(response -> { onCancelTaskResponse(response, detectorId, LOG); }, exception -> { + LOG.error("Failed to cancel task for detectorId: " + detectorId, exception); + throw new InternalFailure(detectorId, "Failed to cancel current tasks", exception); + }) + ); + } + + /** + * Helper function to handle CancelTasksResponse + * @param cancelTasksResponse CancelTasksResponse + * @param detectorId Anomaly Detector Id + * @param LOG Logger + */ + private void onCancelTaskResponse(CancelTasksResponse cancelTasksResponse, String detectorId, Logger LOG) { + // todo: adding retry mechanism + List nodeFailures = cancelTasksResponse.getNodeFailures(); + List taskFailures = cancelTasksResponse.getTaskFailures(); + if (nodeFailures.isEmpty() && taskFailures.isEmpty()) { + LOG.info("Cancelling query for detectorId: {} succeeds. Clear entry from Throttler", detectorId); + throttler.clearFilteredQuery(detectorId); + return; + } + LOG.error("Failed to cancel task for detectorId: " + detectorId); + throw new InternalFailure(detectorId, "Failed to cancel current tasks due to node or task failures"); + } +} diff --git a/src/main/java/org/opensearch/ad/util/DateUtils.java-e b/src/main/java/org/opensearch/ad/util/DateUtils.java-e new file mode 100644 index 000000000..e7cfc21ce --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/DateUtils.java-e @@ -0,0 +1,39 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.util; + +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.time.temporal.ChronoField; + +import org.opensearch.common.unit.TimeValue; + +public class DateUtils { + public static final ZoneId UTC = ZoneId.of("Z"); + + /** + * Get hour of day of the input time in UTC + * @param instant input time + * + * @return Hour of day + */ + public static int getUTCHourOfDay(Instant instant) { + ZonedDateTime time = ZonedDateTime.ofInstant(instant, UTC); + return time.get(ChronoField.HOUR_OF_DAY); + } + + public static Duration toDuration(TimeValue timeValue) { + return Duration.ofMillis(timeValue.millis()); + } +} diff --git a/src/main/java/org/opensearch/ad/util/ExceptionUtil.java b/src/main/java/org/opensearch/ad/util/ExceptionUtil.java index 56e0e9856..b48cf49e4 100644 --- a/src/main/java/org/opensearch/ad/util/ExceptionUtil.java +++ b/src/main/java/org/opensearch/ad/util/ExceptionUtil.java @@ -23,10 +23,10 @@ import org.opensearch.action.UnavailableShardsException; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.replication.ReplicationResponse; -import org.opensearch.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; +import org.opensearch.core.rest.RestStatus; import org.opensearch.index.IndexNotFoundException; -import org.opensearch.rest.RestStatus; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.TimeSeriesException; diff --git a/src/main/java/org/opensearch/ad/util/ExceptionUtil.java-e b/src/main/java/org/opensearch/ad/util/ExceptionUtil.java-e new file mode 100644 index 000000000..e0190df8b --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/ExceptionUtil.java-e @@ -0,0 +1,193 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.util; + +import java.util.EnumSet; +import java.util.concurrent.RejectedExecutionException; + +import org.apache.commons.lang.exception.ExceptionUtils; +import org.apache.logging.log4j.core.util.Throwables; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.NoShardAvailableActionException; +import org.opensearch.action.UnavailableShardsException; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.replication.ReplicationResponse; +import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; + +public class ExceptionUtil { + // a positive cache of retriable error rest status + private static final EnumSet RETRYABLE_STATUS = EnumSet + .of(RestStatus.REQUEST_TIMEOUT, RestStatus.CONFLICT, RestStatus.INTERNAL_SERVER_ERROR); + + /** + * OpenSearch restricts the kind of exceptions can be thrown over the wire + * (See OpenSearchException.OpenSearchExceptionHandle). Since we cannot + * add our own exception like ResourceNotFoundException without modifying + * OpenSearch's code, we have to unwrap the remote transport exception and + * check its root cause message. + * + * @param exception exception thrown locally or over the wire + * @param expected expected root cause + * @param expectedExceptionName expected exception name + * @return whether the exception wraps the expected exception as the cause + */ + public static boolean isException(Throwable exception, Class expected, String expectedExceptionName) { + if (exception == null) { + return false; + } + + if (expected.isAssignableFrom(exception.getClass())) { + return true; + } + + // all exception that has not been registered to sent over wire can be wrapped + // inside NotSerializableExceptionWrapper. + // see StreamOutput.writeException + // OpenSearchException.getExceptionName(exception) returns exception + // separated by underscore. For example, ResourceNotFoundException is converted + // to "resource_not_found_exception". + if (exception instanceof NotSerializableExceptionWrapper && exception.getMessage().trim().startsWith(expectedExceptionName)) { + return true; + } + return false; + } + + /** + * Get failure of all shards. + * + * @param response index response + * @return composite failures of all shards + */ + public static String getShardsFailure(IndexResponse response) { + StringBuilder failureReasons = new StringBuilder(); + if (response.getShardInfo() != null && response.getShardInfo().getFailed() > 0) { + for (ReplicationResponse.ShardInfo.Failure failure : response.getShardInfo().getFailures()) { + failureReasons.append(failure.reason()); + } + return failureReasons.toString(); + } + return null; + } + + /** + * Count exception in AD failure stats of not. 
+ * + * @param e exception + * @return true if should count in AD failure stats; otherwise return false + */ + public static boolean countInStats(Exception e) { + if (!(e instanceof TimeSeriesException) || ((TimeSeriesException) e).isCountedInStats()) { + return true; + } + return false; + } + + /** + * Get error message from exception. + * + * @param e exception + * @return readable error message or full stack trace + */ + public static String getErrorMessage(Exception e) { + if (e instanceof IllegalArgumentException || e instanceof TimeSeriesException) { + return e.getMessage(); + } else if (e instanceof OpenSearchException) { + return ((OpenSearchException) e).getDetailedMessage(); + } else { + return ExceptionUtils.getFullStackTrace(e); + } + } + + /** + * + * @param exception Exception + * @return whether the cause indicates the cluster is overloaded + */ + public static boolean isOverloaded(Throwable exception) { + Throwable cause = Throwables.getRootCause(exception); + // LimitExceededException may indicate circuit breaker exception + // UnavailableShardsException can happen when the system cannot respond + // to requests + return cause instanceof RejectedExecutionException + || cause instanceof OpenSearchRejectedExecutionException + || cause instanceof UnavailableShardsException + || cause instanceof LimitExceededException; + } + + public static boolean isRetryAble(Exception e) { + Throwable cause = ExceptionsHelper.unwrapCause(e); + RestStatus status = ExceptionsHelper.status(cause); + return isRetryAble(status); + } + + public static boolean isRetryAble(RestStatus status) { + return RETRYABLE_STATUS.contains(status); + } + + /** + * Wrap a listener to return the given exception no matter what + * @param The type of listener response + * @param original Original listener + * @param exceptionToReturn The exception to return + * @param detectorId Detector Id + * @return the wrapped listener + */ + public static ActionListener wrapListener(ActionListener original, Exception exceptionToReturn, String detectorId) { + return ActionListener + .wrap( + r -> { original.onFailure(exceptionToReturn); }, + e -> { original.onFailure(selectHigherPriorityException(exceptionToReturn, e)); } + ); + } + + /** + * Return an exception that has higher priority. + * If an exception is EndRunException while another one is not, the former has + * higher priority. + * If both exceptions are EndRunException, the one with end now true has higher + * priority. + * Otherwise, return the second given exception. + * @param exception1 Exception 1 + * @param exception2 Exception 2 + * @return high priority exception + */ + public static Exception selectHigherPriorityException(Exception exception1, Exception exception2) { + if (exception1 instanceof EndRunException) { + // we have already had EndRunException. 
Don't replace it with something less severe + EndRunException endRunException = (EndRunException) exception1; + if (endRunException.isEndNow()) { + // don't proceed if recorded exception is ending now + return exception1; + } + if (false == (exception2 instanceof EndRunException) || false == ((EndRunException) exception2).isEndNow()) { + // don't proceed if the giving exception is not ending now + return exception1; + } + } + return exception2; + } + + public static boolean isIndexNotAvailable(Exception e) { + if (e == null) { + return false; + } + return e instanceof IndexNotFoundException || e instanceof NoShardAvailableActionException; + } +} diff --git a/src/main/java/org/opensearch/ad/util/IndexUtils.java-e b/src/main/java/org/opensearch/ad/util/IndexUtils.java-e new file mode 100644 index 000000000..b69c0924a --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/IndexUtils.java-e @@ -0,0 +1,147 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import java.util.List; +import java.util.Locale; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.admin.indices.stats.IndicesStatsRequest; +import org.opensearch.action.admin.indices.stats.IndicesStatsResponse; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlockLevel; +import org.opensearch.cluster.health.ClusterIndexHealth; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; + +public class IndexUtils { + /** + * Status string of index that does not exist + */ + public static final String NONEXISTENT_INDEX_STATUS = "non-existent"; + + /** + * Status string when an alias exists, but does not point to an index + */ + public static final String ALIAS_EXISTS_NO_INDICES_STATUS = "alias exists, but does not point to any indices"; + public static final String ALIAS_POINTS_TO_MULTIPLE_INDICES_STATUS = "alias exists, but does not point to any " + "indices"; + + private static final Logger logger = LogManager.getLogger(IndexUtils.class); + + private Client client; + private ClientUtil clientUtil; + private ClusterService clusterService; + private final IndexNameExpressionResolver indexNameExpressionResolver; + + /** + * Inject annotation required by Guice to instantiate EntityResultTransportAction (transitive dependency) + * + * @param client Client to make calls to OpenSearch + * @param clientUtil AD Client utility + * @param clusterService ES ClusterService + * @param indexNameExpressionResolver index name resolver + */ + @Inject + public IndexUtils( + Client client, + ClientUtil clientUtil, + ClusterService clusterService, + IndexNameExpressionResolver indexNameExpressionResolver + ) { + this.client = client; + this.clientUtil = clientUtil; + this.clusterService = clusterService; + this.indexNameExpressionResolver = indexNameExpressionResolver; + } + + /** + * Gets the cluster index health for a particular index or the index an alias points to + * + * If an 
alias is passed in, it will only return the health status of an index it points to if it only points to a + * single index. If it points to multiple indices, it will throw an exception. + * + * @param indexOrAliasName String of the index or alias name to get health of. + * @return String represents the status of the index: "red", "yellow" or "green" + * @throws IllegalArgumentException Thrown when an alias is passed in that points to more than one index + */ + public String getIndexHealthStatus(String indexOrAliasName) throws IllegalArgumentException { + if (!clusterService.state().getRoutingTable().hasIndex(indexOrAliasName)) { + // Check if the index is actually an alias + if (clusterService.state().metadata().hasAlias(indexOrAliasName)) { + // List of all indices the alias refers to + List indexMetaDataList = clusterService + .state() + .metadata() + .getIndicesLookup() + .get(indexOrAliasName) + .getIndices(); + if (indexMetaDataList.size() == 0) { + return ALIAS_EXISTS_NO_INDICES_STATUS; + } else if (indexMetaDataList.size() > 1) { + throw new IllegalArgumentException("Cannot get health for alias that points to multiple indices"); + } else { + indexOrAliasName = indexMetaDataList.get(0).getIndex().getName(); + } + } else { + return NONEXISTENT_INDEX_STATUS; + } + } + + ClusterIndexHealth indexHealth = new ClusterIndexHealth( + clusterService.state().metadata().index(indexOrAliasName), + clusterService.state().getRoutingTable().index(indexOrAliasName) + ); + + return indexHealth.getStatus().name().toLowerCase(Locale.ROOT); + } + + /** + * Gets the number of documents in an index. + * + * @deprecated + * + * @param indexName Name of the index + * @return The number of documents in an index. 0 is returned if the index does not exist. -1 is returned if the + * request fails. + */ + @Deprecated + public Long getNumberOfDocumentsInIndex(String indexName) { + if (!clusterService.state().getRoutingTable().hasIndex(indexName)) { + return 0L; + } + IndicesStatsRequest indicesStatsRequest = new IndicesStatsRequest(); + Optional response = clientUtil.timedRequest(indicesStatsRequest, logger, client.admin().indices()::stats); + return response.map(r -> r.getIndex(indexName).getPrimaries().docs.getCount()).orElse(-1L); + } + + /** + * Similar to checkGlobalBlock, we check block on the indices level. + * + * @param state Cluster state + * @param level block level + * @param indices the indices on which to check block + * @return whether any of the index has block on the level. + */ + public boolean checkIndicesBlocked(ClusterState state, ClusterBlockLevel level, String... indices) { + // the original index might be an index expression with wildcards like "log*", + // so we need to expand the expression to concrete index name + String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(state, IndicesOptions.lenientExpandOpen(), indices); + + return state.blocks().indicesBlockedException(level, concreteIndices) != null; + } +} diff --git a/src/main/java/org/opensearch/ad/util/MultiResponsesDelegateActionListener.java-e b/src/main/java/org/opensearch/ad/util/MultiResponsesDelegateActionListener.java-e new file mode 100644 index 000000000..8b18bf9c3 --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/MultiResponsesDelegateActionListener.java-e @@ -0,0 +1,107 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.model.Mergeable; + +/** + * A listener wrapper to help send multiple requests asynchronously and return one final responses together + */ +public class MultiResponsesDelegateActionListener implements ActionListener { + private static final Logger LOG = LogManager.getLogger(MultiResponsesDelegateActionListener.class); + static final String NO_RESPONSE = "No response collected"; + + private final ActionListener delegate; + private final AtomicInteger collectedResponseCount; + private final AtomicInteger maxResponseCount; + // save responses from multiple requests + private final List savedResponses; + private List exceptions; + private String finalErrorMsg; + private final boolean returnOnPartialResults; + + public MultiResponsesDelegateActionListener( + ActionListener delegate, + int maxResponseCount, + String finalErrorMsg, + boolean returnOnPartialResults + ) { + this.delegate = delegate; + this.collectedResponseCount = new AtomicInteger(0); + this.maxResponseCount = new AtomicInteger(maxResponseCount); + this.savedResponses = Collections.synchronizedList(new ArrayList()); + this.exceptions = Collections.synchronizedList(new ArrayList()); + this.finalErrorMsg = finalErrorMsg; + this.returnOnPartialResults = returnOnPartialResults; + } + + @Override + public void onResponse(T response) { + try { + if (response != null) { + this.savedResponses.add(response); + } + } finally { + // If expectedResponseCount == 0 , collectedResponseCount.incrementAndGet() will be greater than expectedResponseCount + if (collectedResponseCount.incrementAndGet() >= maxResponseCount.get()) { + finish(); + } + } + + } + + @Override + public void onFailure(Exception e) { + LOG.error("Failure in response", e); + try { + this.exceptions.add(e.getMessage()); + } finally { + // no matter the asynchronous request is a failure or success, we need to increment the count. + // We need finally here to increment the count when there is a failure. 
+ if (collectedResponseCount.incrementAndGet() >= maxResponseCount.get()) { + finish(); + } + } + } + + private void finish() { + if (this.returnOnPartialResults || this.exceptions.size() == 0) { + if (this.exceptions.size() > 0) { + LOG.error(String.format(Locale.ROOT, "Although returning result, there exists exceptions: %s", this.exceptions)); + } + handleSavedResponses(); + } else { + this.delegate.onFailure(new RuntimeException(String.format(Locale.ROOT, finalErrorMsg + " Exceptions: %s", exceptions))); + } + } + + private void handleSavedResponses() { + if (savedResponses.size() == 0) { + this.delegate.onFailure(new RuntimeException(NO_RESPONSE)); + } else { + T response0 = savedResponses.get(0); + for (int i = 1; i < savedResponses.size(); i++) { + response0.merge(savedResponses.get(i)); + } + this.delegate.onResponse(response0); + } + } +} diff --git a/src/main/java/org/opensearch/ad/util/SafeSecurityInjector.java-e b/src/main/java/org/opensearch/ad/util/SafeSecurityInjector.java-e new file mode 100644 index 000000000..612ea4d5c --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/SafeSecurityInjector.java-e @@ -0,0 +1,87 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import java.util.List; +import java.util.Locale; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.InjectSecurity; + +public abstract class SafeSecurityInjector implements AutoCloseable { + private static final Logger LOG = LogManager.getLogger(SafeSecurityInjector.class); + // user header used by security plugin. As we cannot take security plugin as + // a compile dependency, we have to duplicate it here. + private static final String OPENDISTRO_SECURITY_USER = "_opendistro_security_user"; + + private InjectSecurity rolesInjectorHelper; + protected String id; + protected Settings settings; + protected ThreadContext tc; + + public SafeSecurityInjector(String id, Settings settings, ThreadContext tc) { + this.id = id; + this.settings = settings; + this.tc = tc; + this.rolesInjectorHelper = null; + } + + protected boolean shouldInject() { + if (id == null || settings == null || tc == null) { + LOG.debug(String.format(Locale.ROOT, "null value: id: %s, settings: %s, threadContext: %s", id, settings, tc)); + return false; + } + // user not null means the request comes from user (e.g., public restful API) + // we don't need to inject roles. 
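A usage sketch for MultiResponsesDelegateActionListener, defined earlier in this patch: one delegate is fanned out over a fixed number of asynchronous calls, and the collected Mergeable responses are merged into a single result once every call has reported back. All names below are hypothetical:

    import org.opensearch.action.ActionListener;
    import org.opensearch.ad.model.Mergeable;
    import org.opensearch.ad.util.MultiResponsesDelegateActionListener;

    final class FanOutSketch {
        interface AsyncCall<T extends Mergeable> {
            void run(ActionListener<T> listener);
        }

        // Issue two async calls; the wrapper merges their responses (or fails) once both have returned.
        static <T extends Mergeable> void fanOut(ActionListener<T> delegate, AsyncCall<T> first, AsyncCall<T> second) {
            MultiResponsesDelegateActionListener<T> multi =
                new MultiResponsesDelegateActionListener<>(delegate, 2, "Failed to collect responses", false);
            first.run(multi);  // each call eventually invokes multi.onResponse(...) or multi.onFailure(...)
            second.run(multi);
        }
    }
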
+ Object userIn = tc.getTransient(OPENDISTRO_SECURITY_USER); + if (userIn != null) { + LOG.debug(new ParameterizedMessage("User not empty in thread context: [{}]", userIn)); + return false; + } + userIn = tc.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + if (userIn != null) { + LOG.debug(new ParameterizedMessage("User not empty in thread context: [{}]", userIn)); + return false; + } + Object rolesin = tc.getTransient(ConfigConstants.OPENSEARCH_SECURITY_INJECTED_ROLES); + if (rolesin != null) { + LOG.warn(new ParameterizedMessage("Injected roles not empty in thread context: [{}]", rolesin)); + return false; + } + + return true; + } + + protected void inject(String user, List roles) { + if (roles == null) { + LOG.warn("Cannot inject empty roles in thread context"); + return; + } + if (rolesInjectorHelper == null) { + // lazy init + rolesInjectorHelper = new InjectSecurity(id, settings, tc); + } + rolesInjectorHelper.inject(user, roles); + } + + @Override + public void close() { + if (rolesInjectorHelper != null) { + rolesInjectorHelper.close(); + } + } +} diff --git a/src/main/java/org/opensearch/ad/util/SecurityClientUtil.java-e b/src/main/java/org/opensearch/ad/util/SecurityClientUtil.java-e new file mode 100644 index 000000000..8e9b97b57 --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/SecurityClientUtil.java-e @@ -0,0 +1,130 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import java.util.function.BiConsumer; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionResponse; +import org.opensearch.action.ActionType; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; + +public class SecurityClientUtil { + private static final String INJECTION_ID = "direct"; + private NodeStateManager nodeStateManager; + private Settings settings; + + @Inject + public SecurityClientUtil(NodeStateManager nodeStateManager, Settings settings) { + this.nodeStateManager = nodeStateManager; + this.settings = settings; + } + + /** + * Send an asynchronous request in the context of user role and handle response with the provided listener. The role + * is recorded in a detector config. 
+ * @param ActionRequest + * @param ActionResponse + * @param request request body + * @param consumer request method, functional interface to operate as a client request like client::get + * @param detectorId Detector id + * @param client OpenSearch client + * @param listener needed to handle response + */ + public void asyncRequestWithInjectedSecurity( + Request request, + BiConsumer> consumer, + String detectorId, + Client client, + ActionListener listener + ) { + ThreadContext threadContext = client.threadPool().getThreadContext(); + try (ADSafeSecurityInjector injectSecurity = new ADSafeSecurityInjector(detectorId, settings, threadContext, nodeStateManager)) { + injectSecurity + .injectUserRolesFromDetector( + ActionListener + .wrap( + success -> consumer.accept(request, ActionListener.runBefore(listener, () -> injectSecurity.close())), + listener::onFailure + ) + ); + } + } + + /** + * Send an asynchronous request in the context of user role and handle response with the provided listener. The role + * is provided in the arguments. + * @param ActionRequest + * @param ActionResponse + * @param request request body + * @param consumer request method, functional interface to operate as a client request like client::get + * @param user User info + * @param client OpenSearch client + * @param listener needed to handle response + */ + public void asyncRequestWithInjectedSecurity( + Request request, + BiConsumer> consumer, + User user, + Client client, + ActionListener listener + ) { + ThreadContext threadContext = client.threadPool().getThreadContext(); + // use a hardcoded string as detector id that is only used in logging + // Question: + // Will the try-with-resources statement auto close injectSecurity? + // Here the injectSecurity is closed explicitly. So we don't need to put the injectSecurity inside try ? + // Explanation: + // There might be two threads: one thread covers try, inject, and triggers client.execute/client.search + // (this can be a thread in the write thread pool); another thread actually execute the logic of + // client.execute/client.search and handles the responses (this can be a thread in the search thread pool). + // Auto-close in try will restore the context in one thread; the explicit close injectSecurity will restore + // the context in another thread. So we still need to put the injectSecurity inside try. + try (ADSafeSecurityInjector injectSecurity = new ADSafeSecurityInjector(INJECTION_ID, settings, threadContext, nodeStateManager)) { + injectSecurity.injectUserRoles(user); + consumer.accept(request, ActionListener.runBefore(listener, () -> injectSecurity.close())); + } + } + + /** + * Execute a transport action in the context of user role and handle response with the provided listener. The role + * is provided in the arguments. 
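A usage sketch for asyncRequestWithInjectedSecurity above: the caller hands over the client method reference, and the search then runs with the roles recorded in the detector configuration rather than whatever happens to be in the transport thread's context. The wrapper class and variable names below are hypothetical:

    import org.opensearch.action.ActionListener;
    import org.opensearch.action.search.SearchRequest;
    import org.opensearch.action.search.SearchResponse;
    import org.opensearch.ad.util.SecurityClientUtil;
    import org.opensearch.client.Client;

    final class InjectedSearchSketch {
        // Run a search with the roles stored in the detector's configuration.
        static void searchAsDetectorUser(
            SecurityClientUtil util,
            Client client,
            String detectorId,
            SearchRequest request,
            ActionListener<SearchResponse> listener
        ) {
            util.asyncRequestWithInjectedSecurity(request, client::search, detectorId, client, listener);
        }
    }
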
+ * @param ActionRequest + * @param ActionResponse + * @param action transport action + * @param request request body + * @param user User info + * @param client OpenSearch client + * @param listener needed to handle response + */ + public void executeWithInjectedSecurity( + ActionType action, + Request request, + User user, + Client client, + ActionListener listener + ) { + ThreadContext threadContext = client.threadPool().getThreadContext(); + + // use a hardcoded string as detector id that is only used in logging + try (ADSafeSecurityInjector injectSecurity = new ADSafeSecurityInjector(INJECTION_ID, settings, threadContext, nodeStateManager)) { + injectSecurity.injectUserRoles(user); + client.execute(action, request, ActionListener.runBefore(listener, () -> injectSecurity.close())); + } + } +} diff --git a/src/main/java/org/opensearch/ad/util/SecurityUtil.java-e b/src/main/java/org/opensearch/ad/util/SecurityUtil.java-e new file mode 100644 index 000000000..d72d345ab --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/SecurityUtil.java-e @@ -0,0 +1,77 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import java.util.Collections; +import java.util.List; + +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.common.settings.Settings; +import org.opensearch.commons.authuser.User; + +import com.google.common.collect.ImmutableList; + +public class SecurityUtil { + /** + * @param userObj the last user who edited the detector config + * @param settings Node settings + * @return converted user for bwc if necessary + */ + private static User getAdjustedUserBWC(User userObj, Settings settings) { + /* + * We need to handle 3 cases: + * 1. Detectors created by older versions and never updated. These detectors wont have User details in the + * detector object. `detector.user` will be null. Insert `all_access, AmazonES_all_access` role. + * 2. Detectors are created when security plugin is disabled, these will have empty User object. + * (`detector.user.name`, `detector.user.roles` are empty ) + * 3. Detectors are created when security plugin is enabled, these will have an User object. + * This will inject user role and check if the user role has permissions to call the execute + * Anomaly Result API. + */ + String user; + List roles; + if (userObj == null) { + // It's possible that user create domain with security disabled, then enable security + // after upgrading. This is for BWC, for old detectors which created when security + // disabled, the user will be null. + // This is a huge promotion in privileges. To prevent a caller code from making a mistake and pass a null object, + // we make the method private and only allow fetching user object from detector or job configuration (see the public + // access methods with the same name). + user = ""; + roles = settings.getAsList("", ImmutableList.of("all_access", "AmazonES_all_access")); + return new User(user, Collections.emptyList(), roles, Collections.emptyList()); + } else { + return userObj; + } + } + + /** + * * + * @param detector Detector config + * @param settings Node settings + * @return user recorded by a detector. 
Made adjstument for BWC (backward-compatibility) if necessary. + */ + public static User getUserFromDetector(AnomalyDetector detector, Settings settings) { + return getAdjustedUserBWC(detector.getUser(), settings); + } + + /** + * * + * @param detectorJob Detector Job + * @param settings Node settings + * @return user recorded by a detector job + */ + public static User getUserFromJob(AnomalyDetectorJob detectorJob, Settings settings) { + return getAdjustedUserBWC(detectorJob.getUser(), settings); + } +} diff --git a/src/main/java/org/opensearch/ad/util/Throttler.java-e b/src/main/java/org/opensearch/ad/util/Throttler.java-e new file mode 100644 index 000000000..177b612a2 --- /dev/null +++ b/src/main/java/org/opensearch/ad/util/Throttler.java-e @@ -0,0 +1,73 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import java.time.Clock; +import java.time.Instant; +import java.util.AbstractMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +import org.opensearch.action.ActionRequest; + +/** + * Utility functions for throttling query. + */ +public class Throttler { + // negativeCache is used to reject search query if given detector already has one query running + // key is detectorId, value is an entry. Key is ActionRequest and value is the timestamp + private final ConcurrentHashMap> negativeCache; + private final Clock clock; + + public Throttler(Clock clock) { + this.negativeCache = new ConcurrentHashMap<>(); + this.clock = clock; + } + + /** + * This will be used when dependency injection directly/indirectly injects a Throttler object. Without this object, + * node start might fail due to not being able to find a Clock object. We removed Clock object association in + * https://github.com/opendistro-for-elasticsearch/anomaly-detection/pull/305 + */ + public Throttler() { + this(Clock.systemUTC()); + } + + /** + * Get negative cache value(ActionRequest, Instant) for given detector + * @param detectorId AnomalyDetector ID + * @return negative cache value(ActionRequest, Instant) + */ + public Optional> getFilteredQuery(String detectorId) { + return Optional.ofNullable(negativeCache.get(detectorId)); + } + + /** + * Insert the negative cache entry for given detector + * If key already exists, return false. Otherwise true. + * @param detectorId AnomalyDetector ID + * @param request ActionRequest + * @return true if key doesn't exist otherwise false. + */ + public synchronized boolean insertFilteredQuery(String detectorId, ActionRequest request) { + return negativeCache.putIfAbsent(detectorId, new AbstractMap.SimpleEntry<>(request, clock.instant())) == null; + } + + /** + * Clear the negative cache for given detector. 
+ * @param detectorId AnomalyDetector ID + */ + public void clearFilteredQuery(String detectorId) { + negativeCache.remove(detectorId); + } +} diff --git a/src/main/java/org/opensearch/forecast/constant/ForecastCommonMessages.java-e b/src/main/java/org/opensearch/forecast/constant/ForecastCommonMessages.java-e new file mode 100644 index 000000000..46de0c762 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/constant/ForecastCommonMessages.java-e @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.constant; + +import static org.opensearch.forecast.constant.ForecastCommonName.CUSTOM_RESULT_INDEX_PREFIX; + +public class ForecastCommonMessages { + // ====================================== + // Validation message + // ====================================== + public static String INVALID_FORECAST_INTERVAL = "Forecast interval must be a positive integer"; + public static String NULL_FORECAST_INTERVAL = "Forecast interval should be set"; + public static String INVALID_FORECASTER_NAME = + "Valid characters for forecaster name are a-z, A-Z, 0-9, -(hyphen), _(underscore) and .(period)"; + + // ====================================== + // Resource constraints + // ====================================== + public static final String DISABLED_ERR_MSG = "Forecast functionality is disabled. To enable update plugins.forecast.enabled to true"; + + // ====================================== + // RESTful API + // ====================================== + public static String FAIL_TO_CREATE_FORECASTER = "Failed to create forecaster"; + public static String FAIL_TO_UPDATE_FORECASTER = "Failed to update forecaster"; + public static String FAIL_TO_FIND_FORECASTER_MSG = "Can not find forecaster with id: "; + public static final String FORECASTER_ID_MISSING_MSG = "Forecaster ID is missing"; + public static final String INVALID_TIMESTAMP_ERR_MSG = "timestamp is invalid"; + + // ====================================== + // Security + // ====================================== + public static String NO_PERMISSION_TO_ACCESS_FORECASTER = "User does not have permissions to access forecaster: "; + public static String FAIL_TO_GET_USER_INFO = "Unable to get user information from forecaster "; + + // ====================================== + // Used for custom forecast result index + // ====================================== + public static String INVALID_RESULT_INDEX_PREFIX = "Result index must start with " + CUSTOM_RESULT_INDEX_PREFIX; + + // ====================================== + // Task + // ====================================== + public static String FORECASTER_IS_RUNNING = "Forecaster is already running"; +} diff --git a/src/main/java/org/opensearch/forecast/constant/ForecastCommonName.java-e b/src/main/java/org/opensearch/forecast/constant/ForecastCommonName.java-e new file mode 100644 index 000000000..8edaf2d2b --- /dev/null +++ b/src/main/java/org/opensearch/forecast/constant/ForecastCommonName.java-e @@ -0,0 +1,48 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.constant; + +public class ForecastCommonName { + // ====================================== + // Validation + // ====================================== + // detector validation aspect + public static final String FORECASTER_ASPECT = "forecaster"; + + // ====================================== + // Used for custom forecast result index + // ====================================== + public static final String DUMMY_FORECAST_RESULT_ID = "dummy_forecast_result_id"; + public static final String DUMMY_FORECASTER_ID = "dummy_forecaster_id"; + public static final String CUSTOM_RESULT_INDEX_PREFIX = "opensearch-forecast-result-"; + + // ====================================== + // Index name + // ====================================== + // index name for forecast checkpoint of each model. One model one document. + public static final String FORECAST_CHECKPOINT_INDEX_NAME = ".opensearch-forecast-checkpoints"; + // index name for forecast state. Will store forecast task in this index as well. + public static final String FORECAST_STATE_INDEX = ".opensearch-forecast-state"; + // The alias of the index in which to write forecast result history. Not a hidden index. + // Allow users to create dashboard or query freely on top of it. + public static final String FORECAST_RESULT_INDEX_ALIAS = "opensearch-forecast-results"; + + // ====================================== + // Used in toXContent + // ====================================== + public static final String ID_JSON_KEY = "forecasterID"; + + // ====================================== + // Used in stats API + // ====================================== + public static final String FORECASTER_ID_KEY = "forecaster_id"; +} diff --git a/src/main/java/org/opensearch/forecast/constant/ForecastCommonValue.java b/src/main/java/org/opensearch/forecast/constant/ForecastCommonValue.java new file mode 100644 index 000000000..27a4de5ed --- /dev/null +++ b/src/main/java/org/opensearch/forecast/constant/ForecastCommonValue.java @@ -0,0 +1,17 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.constant; + +public class ForecastCommonValue { + public static String INTERNAL_ACTION_PREFIX = "cluster:admin/plugin/forecastinternal/"; + public static String EXTERNAL_ACTION_PREFIX = "cluster:admin/plugin/forecast/"; +} diff --git a/src/main/java/org/opensearch/forecast/constant/ForecastCommonValue.java-e b/src/main/java/org/opensearch/forecast/constant/ForecastCommonValue.java-e new file mode 100644 index 000000000..27a4de5ed --- /dev/null +++ b/src/main/java/org/opensearch/forecast/constant/ForecastCommonValue.java-e @@ -0,0 +1,17 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.constant; + +public class ForecastCommonValue { + public static String INTERNAL_ACTION_PREFIX = "cluster:admin/plugin/forecastinternal/"; + public static String EXTERNAL_ACTION_PREFIX = "cluster:admin/plugin/forecast/"; +} diff --git a/src/main/java/org/opensearch/forecast/indices/ForecastIndex.java-e b/src/main/java/org/opensearch/forecast/indices/ForecastIndex.java-e new file mode 100644 index 000000000..8e514dd6e --- /dev/null +++ b/src/main/java/org/opensearch/forecast/indices/ForecastIndex.java-e @@ -0,0 +1,72 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.indices; + +import java.util.function.Supplier; + +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ThrowingSupplierWrapper; +import org.opensearch.timeseries.indices.TimeSeriesIndex; + +public enum ForecastIndex implements TimeSeriesIndex { + // throw RuntimeException since we don't know how to handle the case when the mapping reading throws IOException + RESULT( + ForecastCommonName.FORECAST_RESULT_INDEX_ALIAS, + true, + ThrowingSupplierWrapper.throwingSupplierWrapper(ForecastIndexManagement::getResultMappings) + ), + CONFIG(CommonName.CONFIG_INDEX, false, ThrowingSupplierWrapper.throwingSupplierWrapper(ADIndexManagement::getConfigMappings)), + JOB(CommonName.JOB_INDEX, false, ThrowingSupplierWrapper.throwingSupplierWrapper(ADIndexManagement::getJobMappings)), + CHECKPOINT( + ForecastCommonName.FORECAST_CHECKPOINT_INDEX_NAME, + false, + ThrowingSupplierWrapper.throwingSupplierWrapper(ForecastIndexManagement::getCheckpointMappings) + ), + STATE( + ForecastCommonName.FORECAST_STATE_INDEX, + false, + ThrowingSupplierWrapper.throwingSupplierWrapper(ForecastIndexManagement::getStateMappings) + ); + + private final String indexName; + // whether we use an alias for the index + private final boolean alias; + private final String mapping; + + ForecastIndex(String name, boolean alias, Supplier mappingSupplier) { + this.indexName = name; + this.alias = alias; + this.mapping = mappingSupplier.get(); + } + + @Override + public String getIndexName() { + return indexName; + } + + @Override + public boolean isAlias() { + return alias; + } + + @Override + public String getMapping() { + return mapping; + } + + @Override + public boolean isJobIndex() { + return CommonName.JOB_INDEX.equals(indexName); + } +} diff --git a/src/main/java/org/opensearch/forecast/indices/ForecastIndexManagement.java-e b/src/main/java/org/opensearch/forecast/indices/ForecastIndexManagement.java-e new file mode 100644 index 000000000..db8b40d42 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/indices/ForecastIndexManagement.java-e @@ -0,0 +1,277 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.indices; + +import static org.opensearch.cluster.metadata.IndexMetadata.SETTING_REPLICATION_TYPE; +import static org.opensearch.forecast.constant.ForecastCommonName.DUMMY_FORECAST_RESULT_ID; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_INDEX_MAPPING_FILE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_MAX_PRIMARY_SHARDS; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_RESULTS_INDEX_MAPPING_FILE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_RESULT_HISTORY_MAX_DOCS_PER_SHARD; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_RESULT_HISTORY_RETENTION_PERIOD; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_RESULT_HISTORY_ROLLOVER_PERIOD; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_STATE_INDEX_MAPPING_FILE; +import static org.opensearch.indices.replication.common.ReplicationType.DOCUMENT; + +import java.io.IOException; +import java.util.EnumMap; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +public class ForecastIndexManagement extends IndexManagement { + private static final Logger logger = LogManager.getLogger(ForecastIndexManagement.class); + + // The index name pattern to query all the forecast result history indices + public static final String FORECAST_RESULT_HISTORY_INDEX_PATTERN = ""; + + // The index name pattern to query all forecast results, history and current forecast results + public static final String ALL_FORECAST_RESULTS_INDEX_PATTERN = "opensearch-forecast-results*"; + + /** + * Constructor function + * + * @param client OS client supports administrative actions + * @param clusterService OS cluster service + * @param threadPool OS thread pool + * @param settings OS cluster setting + * @param nodeFilter Used to filter eligible nodes to host forecast indices + * @param maxUpdateRunningTimes max number of retries to update index mapping and setting + * @throws IOException when failing to get mapping file + */ + public ForecastIndexManagement( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + Settings settings, + DiscoveryNodeFilterer nodeFilter, + int maxUpdateRunningTimes + ) + throws IOException { + super( + client, + clusterService, + threadPool, + settings, + nodeFilter, + maxUpdateRunningTimes, + ForecastIndex.class, + FORECAST_MAX_PRIMARY_SHARDS.get(settings), + FORECAST_RESULT_HISTORY_ROLLOVER_PERIOD.get(settings), + 
FORECAST_RESULT_HISTORY_MAX_DOCS_PER_SHARD.get(settings), + FORECAST_RESULT_HISTORY_RETENTION_PERIOD.get(settings), + ForecastIndex.RESULT.getMapping() + ); + this.indexStates = new EnumMap(ForecastIndex.class); + + this.clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(FORECAST_RESULT_HISTORY_MAX_DOCS_PER_SHARD, it -> historyMaxDocs = it); + + this.clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_RESULT_HISTORY_ROLLOVER_PERIOD, it -> { + historyRolloverPeriod = it; + rescheduleRollover(); + }); + this.clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(FORECAST_RESULT_HISTORY_RETENTION_PERIOD, it -> { historyRetentionPeriod = it; }); + + this.clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_MAX_PRIMARY_SHARDS, it -> maxPrimaryShards = it); + + this.updateRunningTimes = 0; + } + + /** + * Get forecast result index mapping json content. + * + * @return forecast result index mapping + * @throws IOException IOException if mapping file can't be read correctly + */ + public static String getResultMappings() throws IOException { + return getMappings(FORECAST_RESULTS_INDEX_MAPPING_FILE); + } + + /** + * Get forecaster state index mapping json content. + * + * @return forecaster state index mapping + * @throws IOException IOException if mapping file can't be read correctly + */ + public static String getStateMappings() throws IOException { + String forecastStateMappings = getMappings(FORECAST_STATE_INDEX_MAPPING_FILE); + String forecasterIndexMappings = getConfigMappings(); + forecasterIndexMappings = forecasterIndexMappings + .substring(forecasterIndexMappings.indexOf("\"properties\""), forecasterIndexMappings.lastIndexOf("}")); + return forecastStateMappings.replace("FORECASTER_INDEX_MAPPING_PLACE_HOLDER", forecasterIndexMappings); + } + + /** + * Get checkpoint index mapping json content. + * + * @return checkpoint index mapping + * @throws IOException IOException if mapping file can't be read correctly + */ + public static String getCheckpointMappings() throws IOException { + return getMappings(FORECAST_CHECKPOINT_INDEX_MAPPING_FILE); + } + + /** + * default forecaster result index exist or not. + * + * @return true if default forecaster result index exists + */ + @Override + public boolean doesDefaultResultIndexExist() { + return doesAliasExist(ForecastCommonName.FORECAST_RESULT_INDEX_ALIAS); + } + + /** + * Forecast state index exist or not. + * + * @return true if forecast state index exists + */ + @Override + public boolean doesStateIndexExist() { + return doesIndexExist(ForecastCommonName.FORECAST_STATE_INDEX); + } + + /** + * Checkpoint index exist or not. + * + * @return true if checkpoint index exists + */ + @Override + public boolean doesCheckpointIndexExist() { + return doesIndexExist(ForecastCommonName.FORECAST_CHECKPOINT_INDEX_NAME); + } + + /** + * Create the state index. 
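+ * The state index (ForecastCommonName.FORECAST_STATE_INDEX) stores forecast tasks as well as run state.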
+ * + * @param actionListener action called after create index + */ + @Override + public void initStateIndex(ActionListener actionListener) { + try { + Settings replicationSettings = Settings.builder().put(SETTING_REPLICATION_TYPE, DOCUMENT.name()).build(); + CreateIndexRequest request = new CreateIndexRequest(ForecastCommonName.FORECAST_STATE_INDEX, replicationSettings) + .mapping(getStateMappings(), XContentType.JSON) + .settings(settings); + adminClient.indices().create(request, markMappingUpToDate(ForecastIndex.STATE, actionListener)); + } catch (IOException e) { + logger.error("Fail to init forecast state index", e); + actionListener.onFailure(e); + } + } + + /** + * Create the checkpoint index. + * + * @param actionListener action called after create index + * @throws EndRunException EndRunException due to failure to get mapping + */ + @Override + public void initCheckpointIndex(ActionListener actionListener) { + String mapping; + try { + mapping = getCheckpointMappings(); + } catch (IOException e) { + throw new EndRunException("", "Cannot find checkpoint mapping file", true); + } + // forecast indices need RAW (e.g., we want users to be able to consume forecast results as soon as + // possible and send out an alert if a threshold is breached). + Settings replicationSettings = Settings.builder().put(SETTING_REPLICATION_TYPE, DOCUMENT.name()).build(); + CreateIndexRequest request = new CreateIndexRequest(ForecastCommonName.FORECAST_CHECKPOINT_INDEX_NAME, replicationSettings) + .mapping(mapping, XContentType.JSON); + choosePrimaryShards(request, true); + adminClient.indices().create(request, markMappingUpToDate(ForecastIndex.CHECKPOINT, actionListener)); + } + + @Override + protected void rolloverAndDeleteHistoryIndex() { + rolloverAndDeleteHistoryIndex( + ForecastCommonName.FORECAST_RESULT_INDEX_ALIAS, + ALL_FORECAST_RESULTS_INDEX_PATTERN, + FORECAST_RESULT_HISTORY_INDEX_PATTERN, + ForecastIndex.RESULT + ); + } + + /** + * Create config index directly. + * + * @param actionListener action called after create index + * @throws IOException IOException from {@link IndexManagement#getConfigMappings} + */ + @Override + public void initConfigIndex(ActionListener actionListener) throws IOException { + super.initConfigIndex(markMappingUpToDate(ForecastIndex.CONFIG, actionListener)); + } + + /** + * Create job index.
+ * + * @param actionListener action called after create index + */ + @Override + public void initJobIndex(ActionListener actionListener) { + super.initJobIndex(markMappingUpToDate(ForecastIndex.JOB, actionListener)); + } + + @Override + protected IndexRequest createDummyIndexRequest(String resultIndex) throws IOException { + ForecastResult dummyResult = ForecastResult.getDummyResult(); + return new IndexRequest(resultIndex) + .id(DUMMY_FORECAST_RESULT_ID) + .source(dummyResult.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); + } + + @Override + protected DeleteRequest createDummyDeleteRequest(String resultIndex) throws IOException { + return new DeleteRequest(resultIndex).id(DUMMY_FORECAST_RESULT_ID); + } + + @Override + public void initDefaultResultIndexDirectly(ActionListener actionListener) { + initResultIndexDirectly( + FORECAST_RESULT_HISTORY_INDEX_PATTERN, + ForecastIndex.RESULT.getIndexName(), + false, + FORECAST_RESULT_HISTORY_INDEX_PATTERN, + ForecastIndex.RESULT, + actionListener + ); + } + + @Override + public void initCustomResultIndexDirectly(String resultIndex, ActionListener actionListener) { + // throws IOException { + initResultIndexDirectly(resultIndex, null, false, FORECAST_RESULT_HISTORY_INDEX_PATTERN, ForecastIndex.RESULT, actionListener); + } +} diff --git a/src/main/java/org/opensearch/forecast/model/ForecastResult.java b/src/main/java/org/opensearch/forecast/model/ForecastResult.java index 3d1042e2c..1ce75ff63 100644 --- a/src/main/java/org/opensearch/forecast/model/ForecastResult.java +++ b/src/main/java/org/opensearch/forecast/model/ForecastResult.java @@ -11,7 +11,7 @@ package org.opensearch.forecast.model; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.forecast.constant.ForecastCommonName.DUMMY_FORECASTER_ID; import java.io.IOException; @@ -21,10 +21,10 @@ import java.util.Optional; import org.apache.commons.lang.builder.ToStringBuilder; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.commons.authuser.User; import org.opensearch.core.ParseField; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; diff --git a/src/main/java/org/opensearch/forecast/model/ForecastResult.java-e b/src/main/java/org/opensearch/forecast/model/ForecastResult.java-e new file mode 100644 index 000000000..3a4e29493 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/ForecastResult.java-e @@ -0,0 +1,590 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.forecast.constant.ForecastCommonName.DUMMY_FORECASTER_ID; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.ParseField; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.constant.CommonValue; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.base.Objects; + +/** + * Include result returned from RCF model and feature data. + */ +public class ForecastResult extends IndexableResult { + public static final String PARSE_FIELD_NAME = "ForecastResult"; + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + ForecastResult.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) + ); + + public static final String FEATURE_ID_FIELD = "feature_id"; + public static final String VALUE_FIELD = "forecast_value"; + public static final String LOWER_BOUND_FIELD = "forecast_lower_bound"; + public static final String UPPER_BOUND_FIELD = "forecast_upper_bound"; + public static final String INTERVAL_WIDTH_FIELD = "confidence_interval_width"; + public static final String FORECAST_DATA_START_TIME_FIELD = "forecast_data_start_time"; + public static final String FORECAST_DATA_END_TIME_FIELD = "forecast_data_end_time"; + public static final String HORIZON_INDEX_FIELD = "horizon_index"; + + private final String featureId; + private final Float forecastValue; + private final Float lowerBound; + private final Float upperBound; + private final Float confidenceIntervalWidth; + private final Instant forecastDataStartTime; + private final Instant forecastDataEndTime; + private final Integer horizonIndex; + protected final Double dataQuality; + + // used when indexing exception or error or an empty result + public ForecastResult( + String forecasterId, + String taskId, + List featureData, + Instant dataStartTime, + Instant dataEndTime, + Instant executionStartTime, + Instant executionEndTime, + String error, + Optional entity, + User user, + Integer schemaVersion, + String modelId + ) { + this( + forecasterId, + taskId, + Double.NaN, + featureData, + dataStartTime, + dataEndTime, + executionStartTime, + executionEndTime, + error, + entity, + user, + schemaVersion, + modelId, + null, + null, + null, + null, + null, + null, + null + ); + } + + public ForecastResult( + String forecasterId, + String taskId, + Double dataQuality, + List featureData, + Instant dataStartTime, + Instant dataEndTime, + Instant executionStartTime, + Instant executionEndTime, + String error, + Optional entity, + User user, + Integer schemaVersion, + String modelId, + String featureId, + Float 
forecastValue, + Float lowerBound, + Float upperBound, + Instant forecastDataStartTime, + Instant forecastDataEndTime, + Integer horizonIndex + ) { + super( + forecasterId, + featureData, + dataStartTime, + dataEndTime, + executionStartTime, + executionEndTime, + error, + entity, + user, + schemaVersion, + modelId, + taskId + ); + this.featureId = featureId; + this.dataQuality = dataQuality; + this.forecastValue = forecastValue; + this.lowerBound = lowerBound; + this.upperBound = upperBound; + this.confidenceIntervalWidth = lowerBound != null && upperBound != null ? Math.abs(upperBound - lowerBound) : Float.NaN; + this.forecastDataStartTime = forecastDataStartTime; + this.forecastDataEndTime = forecastDataEndTime; + this.horizonIndex = horizonIndex; + } + + public static List fromRawRCFCasterResult( + String forecasterId, + long intervalMillis, + Double dataQuality, + List featureData, + Instant dataStartTime, + Instant dataEndTime, + Instant executionStartTime, + Instant executionEndTime, + String error, + Optional entity, + User user, + Integer schemaVersion, + String modelId, + float[] forecastsValues, + float[] forecastsUppers, + float[] forecastsLowers, + String taskId + ) { + int inputLength = featureData.size(); + int numberOfForecasts = forecastsValues.length / inputLength; + + List convertedForecastValues = new ArrayList<>(numberOfForecasts); + + // store feature data and forecast value separately for easy query on feature data + // we can join them using forecasterId, entityId, and executionStartTime/executionEndTime + convertedForecastValues + .add( + new ForecastResult( + forecasterId, + taskId, + dataQuality, + featureData, + dataStartTime, + dataEndTime, + executionStartTime, + executionEndTime, + error, + entity, + user, + schemaVersion, + modelId, + null, + null, + null, + null, + null, + null, + -1 + ) + ); + Instant forecastDataStartTime = dataEndTime; + + for (int i = 0; i < numberOfForecasts; i++) { + Instant forecastDataEndTime = forecastDataStartTime.plusMillis(intervalMillis); + for (int j = 0; j < inputLength; j++) { + int k = i * inputLength + j; + convertedForecastValues + .add( + new ForecastResult( + forecasterId, + taskId, + dataQuality, + null, + null, + null, + executionStartTime, + executionEndTime, + error, + entity, + user, + schemaVersion, + modelId, + featureData.get(j).getFeatureId(), + forecastsValues[k], + forecastsLowers[k], + forecastsUppers[k], + forecastDataStartTime, + forecastDataEndTime, + i + ) + ); + } + forecastDataStartTime = forecastDataEndTime; + } + + return convertedForecastValues; + } + + public ForecastResult(StreamInput input) throws IOException { + super(input); + this.featureId = input.readOptionalString(); + this.dataQuality = input.readOptionalDouble(); + this.forecastValue = input.readOptionalFloat(); + this.lowerBound = input.readOptionalFloat(); + this.upperBound = input.readOptionalFloat(); + this.confidenceIntervalWidth = input.readOptionalFloat(); + this.forecastDataStartTime = input.readOptionalInstant(); + this.forecastDataEndTime = input.readOptionalInstant(); + this.horizonIndex = input.readOptionalInt(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder + .startObject() + .field(ForecastCommonName.FORECASTER_ID_KEY, configId) + .field(CommonName.SCHEMA_VERSION_FIELD, schemaVersion); + + if (dataStartTime != null) { + xContentBuilder.field(CommonName.DATA_START_TIME_FIELD, dataStartTime.toEpochMilli()); + } + if 
(dataEndTime != null) { + xContentBuilder.field(CommonName.DATA_END_TIME_FIELD, dataEndTime.toEpochMilli()); + } + if (featureData != null) { + // can be null during preview + xContentBuilder.field(CommonName.FEATURE_DATA_FIELD, featureData.toArray()); + } + if (executionStartTime != null) { + // can be null during preview + xContentBuilder.field(CommonName.EXECUTION_START_TIME_FIELD, executionStartTime.toEpochMilli()); + } + if (executionEndTime != null) { + // can be null during preview + xContentBuilder.field(CommonName.EXECUTION_END_TIME_FIELD, executionEndTime.toEpochMilli()); + } + if (error != null) { + xContentBuilder.field(CommonName.ERROR_FIELD, error); + } + if (optionalEntity.isPresent()) { + xContentBuilder.field(CommonName.ENTITY_FIELD, optionalEntity.get()); + } + if (user != null) { + xContentBuilder.field(CommonName.USER_FIELD, user); + } + if (modelId != null) { + xContentBuilder.field(CommonName.MODEL_ID_FIELD, modelId); + } + if (dataQuality != null && !dataQuality.isNaN()) { + xContentBuilder.field(CommonName.DATA_QUALITY_FIELD, dataQuality); + } + if (taskId != null) { + xContentBuilder.field(CommonName.TASK_ID_FIELD, taskId); + } + if (entityId != null) { + xContentBuilder.field(CommonName.ENTITY_ID_FIELD, entityId); + } + if (forecastValue != null) { + xContentBuilder.field(VALUE_FIELD, forecastValue); + } + if (lowerBound != null) { + xContentBuilder.field(LOWER_BOUND_FIELD, lowerBound); + } + if (upperBound != null) { + xContentBuilder.field(UPPER_BOUND_FIELD, upperBound); + } + if (forecastDataStartTime != null) { + xContentBuilder.field(FORECAST_DATA_START_TIME_FIELD, forecastDataStartTime.toEpochMilli()); + } + if (forecastDataEndTime != null) { + xContentBuilder.field(FORECAST_DATA_END_TIME_FIELD, forecastDataEndTime.toEpochMilli()); + } + if (horizonIndex != null) { + xContentBuilder.field(HORIZON_INDEX_FIELD, horizonIndex); + } + if (featureId != null) { + xContentBuilder.field(FEATURE_ID_FIELD, featureId); + } + + return xContentBuilder.endObject(); + } + + public static ForecastResult parse(XContentParser parser) throws IOException { + String forecasterId = null; + Double dataQuality = null; + List featureData = null; + Instant dataStartTime = null; + Instant dataEndTime = null; + Instant executionStartTime = null; + Instant executionEndTime = null; + String error = null; + Entity entity = null; + User user = null; + Integer schemaVersion = CommonValue.NO_SCHEMA_VERSION; + String modelId = null; + String taskId = null; + + String featureId = null; + Float forecastValue = null; + Float lowerBound = null; + Float upperBound = null; + Instant forecastDataStartTime = null; + Instant forecastDataEndTime = null; + Integer horizonIndex = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case ForecastCommonName.FORECASTER_ID_KEY: + forecasterId = parser.text(); + break; + case CommonName.DATA_QUALITY_FIELD: + dataQuality = parser.doubleValue(); + break; + case CommonName.FEATURE_DATA_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + featureData = new ArrayList<>(); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + featureData.add(FeatureData.parse(parser)); + } + break; + case CommonName.DATA_START_TIME_FIELD: + dataStartTime = ParseUtils.toInstant(parser); + break; + case 
CommonName.DATA_END_TIME_FIELD: + dataEndTime = ParseUtils.toInstant(parser); + break; + case CommonName.EXECUTION_START_TIME_FIELD: + executionStartTime = ParseUtils.toInstant(parser); + break; + case CommonName.EXECUTION_END_TIME_FIELD: + executionEndTime = ParseUtils.toInstant(parser); + break; + case CommonName.ERROR_FIELD: + error = parser.text(); + break; + case CommonName.ENTITY_FIELD: + entity = Entity.parse(parser); + break; + case CommonName.USER_FIELD: + user = User.parse(parser); + break; + case CommonName.SCHEMA_VERSION_FIELD: + schemaVersion = parser.intValue(); + break; + case CommonName.MODEL_ID_FIELD: + modelId = parser.text(); + break; + case FEATURE_ID_FIELD: + featureId = parser.text(); + break; + case LOWER_BOUND_FIELD: + lowerBound = parser.floatValue(); + break; + case UPPER_BOUND_FIELD: + upperBound = parser.floatValue(); + break; + case VALUE_FIELD: + forecastValue = parser.floatValue(); + break; + case FORECAST_DATA_START_TIME_FIELD: + forecastDataStartTime = ParseUtils.toInstant(parser); + break; + case FORECAST_DATA_END_TIME_FIELD: + forecastDataEndTime = ParseUtils.toInstant(parser); + break; + case CommonName.TASK_ID_FIELD: + taskId = parser.text(); + break; + case HORIZON_INDEX_FIELD: + horizonIndex = parser.intValue(); + break; + default: + parser.skipChildren(); + break; + } + } + + return new ForecastResult( + forecasterId, + taskId, + dataQuality, + featureData, + dataStartTime, + dataEndTime, + executionStartTime, + executionEndTime, + error, + Optional.ofNullable(entity), + user, + schemaVersion, + modelId, + featureId, + forecastValue, + lowerBound, + upperBound, + forecastDataStartTime, + forecastDataEndTime, + horizonIndex + ); + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + if (!super.equals(o)) + return false; + ForecastResult that = (ForecastResult) o; + return Objects.equal(featureId, that.featureId) + && Objects.equal(dataQuality, that.dataQuality) + && Objects.equal(forecastValue, that.forecastValue) + && Objects.equal(lowerBound, that.lowerBound) + && Objects.equal(upperBound, that.upperBound) + && Objects.equal(confidenceIntervalWidth, that.confidenceIntervalWidth) + && Objects.equal(forecastDataStartTime, that.forecastDataStartTime) + && Objects.equal(forecastDataEndTime, that.forecastDataEndTime) + && Objects.equal(horizonIndex, that.horizonIndex); + } + + @Generated + @Override + public int hashCode() { + final int prime = 31; + int result = super.hashCode(); + result = prime * result + Objects + .hashCode( + featureId, + dataQuality, + forecastValue, + lowerBound, + upperBound, + confidenceIntervalWidth, + forecastDataStartTime, + forecastDataEndTime, + horizonIndex + ); + return result; + } + + @Generated + @Override + public String toString() { + return super.toString() + + ", " + + new ToStringBuilder(this) + .append("featureId", featureId) + .append("dataQuality", dataQuality) + .append("forecastValue", forecastValue) + .append("lowerBound", lowerBound) + .append("upperBound", upperBound) + .append("confidenceIntervalWidth", confidenceIntervalWidth) + .append("forecastDataStartTime", forecastDataStartTime) + .append("forecastDataEndTime", forecastDataEndTime) + .append("horizonIndex", horizonIndex) + .toString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + + out.writeOptionalString(featureId); + out.writeOptionalDouble(dataQuality); + 
out.writeOptionalFloat(forecastValue); + out.writeOptionalFloat(lowerBound); + out.writeOptionalFloat(upperBound); + out.writeOptionalFloat(confidenceIntervalWidth); + out.writeOptionalInstant(forecastDataStartTime); + out.writeOptionalInstant(forecastDataEndTime); + out.writeOptionalInt(horizonIndex); + } + + public static ForecastResult getDummyResult() { + return new ForecastResult( + DUMMY_FORECASTER_ID, + null, + null, + null, + null, + null, + null, + null, + Optional.empty(), + null, + CommonValue.NO_SCHEMA_VERSION, + null + ); + } + + /** + * Used to throw away requests when index pressure is high. + * @return when the error is there. + */ + @Override + public boolean isHighPriority() { + // AnomalyResult.toXContent won't record Double.NaN and thus make it null + return getError() != null; + } + + public Double getDataQuality() { + return dataQuality; + } + + public String getFeatureId() { + return featureId; + } + + public Float getForecastValue() { + return forecastValue; + } + + public Float getLowerBound() { + return lowerBound; + } + + public Float getUpperBound() { + return upperBound; + } + + public Float getConfidenceIntervalWidth() { + return confidenceIntervalWidth; + } + + public Instant getForecastDataStartTime() { + return forecastDataStartTime; + } + + public Instant getForecastDataEndTime() { + return forecastDataEndTime; + } + + public Integer getHorizonIndex() { + return horizonIndex; + } +} diff --git a/src/main/java/org/opensearch/forecast/model/Forecaster.java b/src/main/java/org/opensearch/forecast/model/Forecaster.java index cd17bf573..c572c28db 100644 --- a/src/main/java/org/opensearch/forecast/model/Forecaster.java +++ b/src/main/java/org/opensearch/forecast/model/Forecaster.java @@ -5,7 +5,7 @@ package org.opensearch.forecast.model; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.forecast.constant.ForecastCommonName.CUSTOM_RESULT_INDEX_PREFIX; import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; @@ -16,12 +16,12 @@ import java.util.List; import java.util.Map; -import org.opensearch.common.ParsingException; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.unit.TimeValue; import org.opensearch.commons.authuser.User; import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParseException; diff --git a/src/main/java/org/opensearch/forecast/model/Forecaster.java-e b/src/main/java/org/opensearch/forecast/model/Forecaster.java-e new file mode 100644 index 000000000..e19428c0b --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/Forecaster.java-e @@ -0,0 +1,405 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.forecast.constant.ForecastCommonName.CUSTOM_RESULT_INDEX_PREFIX; +import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; + +import java.io.IOException; 
+import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.ParseField; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParseException; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastNumericSetting; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.constant.CommonValue; +import org.opensearch.timeseries.dataprocessor.ImputationOption; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.TimeConfiguration; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.base.Objects; + +/** + * Similar to AnomalyDetector, Forecaster defines config object. We cannot inherit from + * AnomalyDetector as AnomalyDetector uses detection interval but Forecaster doesn't + * need it and has to set it to null. Detection interval being null would fail + * AnomalyDetector's constructor because detection interval cannot be null. 
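+ * Forecaster therefore extends the shared Config base class directly, adding the forecast interval and horizon on top of the common fields.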
+ */ +public class Forecaster extends Config { + public static final String FORECAST_PARSE_FIELD_NAME = "Forecaster"; + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + Forecaster.class, + new ParseField(FORECAST_PARSE_FIELD_NAME), + it -> parse(it) + ); + + public static final String HORIZON_FIELD = "horizon"; + public static final String FORECAST_INTERVAL_FIELD = "forecast_interval"; + public static final int DEFAULT_HORIZON_SHINGLE_RATIO = 3; + + private Integer horizon; + + public Forecaster( + String forecasterId, + Long version, + String name, + String description, + String timeField, + List indices, + List features, + QueryBuilder filterQuery, + TimeConfiguration forecastInterval, + TimeConfiguration windowDelay, + Integer shingleSize, + Map uiMetadata, + Integer schemaVersion, + Instant lastUpdateTime, + List categoryFields, + User user, + String resultIndex, + Integer horizon, + ImputationOption imputationOption + ) { + super( + forecasterId, + version, + name, + description, + timeField, + indices, + features, + filterQuery, + windowDelay, + shingleSize, + uiMetadata, + schemaVersion, + lastUpdateTime, + categoryFields, + user, + resultIndex, + forecastInterval, + imputationOption + ); + + checkAndThrowValidationErrors(ValidationAspect.FORECASTER); + + if (forecastInterval == null) { + errorMessage = ForecastCommonMessages.NULL_FORECAST_INTERVAL; + issueType = ValidationIssueType.FORECAST_INTERVAL; + } else if (((IntervalTimeConfiguration) forecastInterval).getInterval() <= 0) { + errorMessage = ForecastCommonMessages.INVALID_FORECAST_INTERVAL; + issueType = ValidationIssueType.FORECAST_INTERVAL; + } + + int maxCategoryFields = ForecastNumericSetting.maxCategoricalFields(); + if (categoryFields != null && categoryFields.size() > maxCategoryFields) { + errorMessage = CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields); + issueType = ValidationIssueType.CATEGORY; + } + + if (invalidHorizon(horizon)) { + errorMessage = "Horizon size must be a positive integer no larger than " + + TimeSeriesSettings.MAX_SHINGLE_SIZE * DEFAULT_HORIZON_SHINGLE_RATIO + + ". Got " + + horizon; + issueType = ValidationIssueType.SHINGLE_SIZE_FIELD; + } + + checkAndThrowValidationErrors(ValidationAspect.FORECASTER); + + this.horizon = horizon; + } + + public Forecaster(StreamInput input) throws IOException { + super(input); + horizon = input.readInt(); + } + + @Override + public void writeTo(StreamOutput output) throws IOException { + super.writeTo(output); + output.writeInt(horizon); + } + + public boolean invalidHorizon(Integer horizonToTest) { + return horizonToTest != null + && (horizonToTest < 1 || horizonToTest > TimeSeriesSettings.MAX_SHINGLE_SIZE * DEFAULT_HORIZON_SHINGLE_RATIO); + } + + /** + * Parse raw json content into forecaster instance. + * + * @param parser json based content parser + * @return forecaster instance + * @throws IOException IOException if content can't be parsed correctly + */ + public static Forecaster parse(XContentParser parser) throws IOException { + return parse(parser, null); + } + + public static Forecaster parse(XContentParser parser, String forecasterId) throws IOException { + return parse(parser, forecasterId, null); + } + + /** + * Parse raw json content and given forecaster id into forecaster instance. 
+ * + * @param parser json based content parser + * @param forecasterId forecaster id + * @param version forecaster document version + * @return forecaster instance + * @throws IOException IOException if content can't be parsed correctly + */ + public static Forecaster parse(XContentParser parser, String forecasterId, Long version) throws IOException { + return parse(parser, forecasterId, version, null, null); + } + + /** + * Parse raw json content and given forecaster id into forecaster instance. + * + * @param parser json based content parser + * @param forecasterId forecaster id + * @param version forecast document version + * @param defaultForecastInterval default forecaster interval + * @param defaultForecastWindowDelay default forecaster window delay + * @return forecaster instance + * @throws IOException IOException if content can't be parsed correctly + */ + public static Forecaster parse( + XContentParser parser, + String forecasterId, + Long version, + TimeValue defaultForecastInterval, + TimeValue defaultForecastWindowDelay + ) throws IOException { + String name = null; + String description = ""; + String timeField = null; + List indices = new ArrayList(); + QueryBuilder filterQuery = QueryBuilders.matchAllQuery(); + TimeConfiguration forecastInterval = defaultForecastInterval == null + ? null + : new IntervalTimeConfiguration(defaultForecastInterval.getMinutes(), ChronoUnit.MINUTES); + TimeConfiguration windowDelay = defaultForecastWindowDelay == null + ? null + : new IntervalTimeConfiguration(defaultForecastWindowDelay.getSeconds(), ChronoUnit.SECONDS); + Integer shingleSize = null; + List features = new ArrayList<>(); + Integer schemaVersion = CommonValue.NO_SCHEMA_VERSION; + Map uiMetadata = null; + Instant lastUpdateTime = null; + User user = null; + String resultIndex = null; + + List categoryField = null; + Integer horizon = null; + ImputationOption interpolationOption = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case NAME_FIELD: + name = parser.text(); + break; + case DESCRIPTION_FIELD: + description = parser.text(); + break; + case TIMEFIELD_FIELD: + timeField = parser.text(); + break; + case INDICES_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + indices.add(parser.text()); + } + break; + case UI_METADATA_FIELD: + uiMetadata = parser.map(); + break; + case CommonName.SCHEMA_VERSION_FIELD: + schemaVersion = parser.intValue(); + break; + case FILTER_QUERY_FIELD: + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + try { + filterQuery = parseInnerQueryBuilder(parser); + } catch (ParsingException | XContentParseException e) { + throw new ValidationException( + "Custom query error in data filter: " + e.getMessage(), + ValidationIssueType.FILTER_QUERY, + ValidationAspect.FORECASTER + ); + } catch (IllegalArgumentException e) { + if (!e.getMessage().contains("empty clause")) { + throw e; + } + } + break; + case FORECAST_INTERVAL_FIELD: + try { + forecastInterval = TimeConfiguration.parse(parser); + } catch (Exception e) { + if (e instanceof IllegalArgumentException && e.getMessage().contains(CommonMessages.NEGATIVE_TIME_CONFIGURATION)) { + throw new ValidationException( + "Forecasting interval must be a positive 
integer", + ValidationIssueType.FORECAST_INTERVAL, + ValidationAspect.FORECASTER + ); + } + throw e; + } + break; + case FEATURE_ATTRIBUTES_FIELD: + try { + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + features.add(Feature.parse(parser)); + } + } catch (Exception e) { + if (e instanceof ParsingException || e instanceof XContentParseException) { + throw new ValidationException( + "Custom query error: " + e.getMessage(), + ValidationIssueType.FEATURE_ATTRIBUTES, + ValidationAspect.FORECASTER + ); + } + throw e; + } + break; + case WINDOW_DELAY_FIELD: + try { + windowDelay = TimeConfiguration.parse(parser); + } catch (Exception e) { + if (e instanceof IllegalArgumentException && e.getMessage().contains(CommonMessages.NEGATIVE_TIME_CONFIGURATION)) { + throw new ValidationException( + "Window delay interval must be a positive integer", + ValidationIssueType.WINDOW_DELAY, + ValidationAspect.FORECASTER + ); + } + throw e; + } + break; + case SHINGLE_SIZE_FIELD: + shingleSize = parser.intValue(); + break; + case LAST_UPDATE_TIME_FIELD: + lastUpdateTime = ParseUtils.toInstant(parser); + break; + case CATEGORY_FIELD: + categoryField = (List) parser.list(); + break; + case USER_FIELD: + user = User.parse(parser); + break; + case RESULT_INDEX_FIELD: + resultIndex = parser.text(); + break; + case HORIZON_FIELD: + horizon = parser.intValue(); + break; + case IMPUTATION_OPTION_FIELD: + interpolationOption = ImputationOption.parse(parser); + break; + default: + parser.skipChildren(); + break; + } + } + Forecaster forecaster = new Forecaster( + forecasterId, + version, + name, + description, + timeField, + indices, + features, + filterQuery, + forecastInterval, + windowDelay, + getShingleSize(shingleSize), + uiMetadata, + schemaVersion, + lastUpdateTime, + categoryField, + user, + resultIndex, + horizon, + interpolationOption + ); + return forecaster; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + xContentBuilder = super.toXContent(xContentBuilder, params); + xContentBuilder.field(FORECAST_INTERVAL_FIELD, interval).field(HORIZON_FIELD, horizon); + + return xContentBuilder.endObject(); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + Forecaster forecaster = (Forecaster) o; + return super.equals(o) && Objects.equal(horizon, forecaster.horizon); + } + + @Override + public int hashCode() { + int hash = super.hashCode(); + hash = 89 * hash + (this.horizon != null ? 
this.horizon.hashCode() : 0); + return hash; + } + + @Override + public String validateCustomResultIndex(String resultIndex) { + if (resultIndex != null && !resultIndex.startsWith(CUSTOM_RESULT_INDEX_PREFIX)) { + return ForecastCommonMessages.INVALID_RESULT_INDEX_PREFIX; + } + return super.validateCustomResultIndex(resultIndex); + } + + @Override + protected ValidationAspect getConfigValidationAspect() { + return ValidationAspect.FORECASTER; + } + + public Integer getHorizon() { + return horizon; + } +} diff --git a/src/main/java/org/opensearch/forecast/settings/ForecastEnabledSetting.java-e b/src/main/java/org/opensearch/forecast/settings/ForecastEnabledSetting.java-e new file mode 100644 index 000000000..1db9bf340 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/settings/ForecastEnabledSetting.java-e @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.settings; + +import static java.util.Collections.unmodifiableMap; +import static org.opensearch.common.settings.Setting.Property.Dynamic; +import static org.opensearch.common.settings.Setting.Property.NodeScope; + +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.common.settings.Setting; +import org.opensearch.timeseries.settings.DynamicNumericSetting; + +public class ForecastEnabledSetting extends DynamicNumericSetting { + + /** + * Singleton instance + */ + private static ForecastEnabledSetting INSTANCE; + + /** + * Settings name + */ + public static final String FORECAST_ENABLED = "plugins.forecast.enabled"; + + public static final String FORECAST_BREAKER_ENABLED = "plugins.forecast.breaker.enabled"; + + public static final String FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED = "plugins.forecast.door_keeper_in_cache.enabled"; + + public static final Map<String, Setting<?>> settings = unmodifiableMap(new HashMap<String, Setting<?>>() { + { + /** + * forecast enable/disable setting + */ + put(FORECAST_ENABLED, Setting.boolSetting(FORECAST_ENABLED, true, NodeScope, Dynamic)); + + /** + * forecast breaker enable/disable setting + */ + put(FORECAST_BREAKER_ENABLED, Setting.boolSetting(FORECAST_BREAKER_ENABLED, true, NodeScope, Dynamic)); + + /** + * We have a bloom filter placed in front of inactive entity cache to + * filter out unpopular items that are not likely to appear more + * than once. Whether this bloom filter is enabled or not. + */ + put( + FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED, + Setting.boolSetting(FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED, false, NodeScope, Dynamic) + ); + } + }); + + private ForecastEnabledSetting(Map<String, Setting<?>> settings) { + super(settings); + } + + public static synchronized ForecastEnabledSetting getInstance() { + if (INSTANCE == null) { + INSTANCE = new ForecastEnabledSetting(settings); + } + return INSTANCE; + } + + /** + * Whether forecasting is enabled. If disabled, the time series plugin rejects RESTful requests about forecasting and stops all forecasting jobs. + * @return whether forecasting is enabled. + */ + public static boolean isForecastEnabled() { + return ForecastEnabledSetting.getInstance().getSettingValue(ForecastEnabledSetting.FORECAST_ENABLED); + } + + /** + * Whether forecast circuit breaker is enabled or not. If disabled, an open circuit breaker wouldn't cause a forecast job to be stopped. + * @return whether forecast circuit breaker is enabled or not.
+ */ + public static boolean isForecastBreakerEnabled() { + return ForecastEnabledSetting.getInstance().getSettingValue(ForecastEnabledSetting.FORECAST_BREAKER_ENABLED); + } + + /** + * If enabled, we filter out unpopular items that are not likely to appear more than once. + * @return whether door keeper in cache is enabled or not. + */ + public static boolean isDoorKeeperInCacheEnabled() { + return ForecastEnabledSetting.getInstance().getSettingValue(ForecastEnabledSetting.FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED); + } +} diff --git a/src/main/java/org/opensearch/forecast/settings/ForecastNumericSetting.java-e b/src/main/java/org/opensearch/forecast/settings/ForecastNumericSetting.java-e new file mode 100644 index 000000000..271321575 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/settings/ForecastNumericSetting.java-e @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.settings; + +import static java.util.Collections.unmodifiableMap; + +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.common.settings.Setting; +import org.opensearch.timeseries.settings.DynamicNumericSetting; + +public class ForecastNumericSetting extends DynamicNumericSetting { + /** + * Singleton instance + */ + private static ForecastNumericSetting INSTANCE; + + /** + * Settings name + */ + public static final String CATEGORY_FIELD_LIMIT = "plugins.forecast.category_field_limit"; + + private static final Map<String, Setting<?>> settings = unmodifiableMap(new HashMap<String, Setting<?>>() { + { + // how many categorical fields we support + // The number of category fields won't cause correctness issues for our + // implementation, but can cause performance issues. The more categorical + // fields, the larger the forecast results, intermediate states, and + // more expensive queries (e.g., to get top entities in preview API, we need + // to use scripts in terms aggregation. The more fields, the slower the query).
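+ // defaults to 2 categorical fields; the limit can be set anywhere from 0 to 5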
+ put( + CATEGORY_FIELD_LIMIT, + Setting.intSetting(CATEGORY_FIELD_LIMIT, 2, 0, 5, Setting.Property.NodeScope, Setting.Property.Dynamic) + ); + } + }); + + ForecastNumericSetting(Map> settings) { + super(settings); + } + + public static synchronized ForecastNumericSetting getInstance() { + if (INSTANCE == null) { + INSTANCE = new ForecastNumericSetting(settings); + } + return INSTANCE; + } + + /** + * @return the max number of categorical fields + */ + public static int maxCategoricalFields() { + return ForecastNumericSetting.getInstance().getSettingValue(ForecastNumericSetting.CATEGORY_FIELD_LIMIT); + } +} diff --git a/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java b/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java index b43543cbf..a9d033b78 100644 --- a/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java +++ b/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java @@ -31,7 +31,7 @@ public final class ForecastSettings { // ====================================== // restful apis // ====================================== - public static final Setting REQUEST_TIMEOUT = Setting + public static final Setting FORECAST_REQUEST_TIMEOUT = Setting .positiveTimeSetting( "plugins.forecast.request_timeout", TimeValue.timeValueSeconds(10), @@ -113,10 +113,31 @@ public final class ForecastSettings { public static final Setting FORECAST_MAX_PRIMARY_SHARDS = Setting .intSetting("plugins.forecast.max_primary_shards", 20, 0, 200, Setting.Property.NodeScope, Setting.Property.Dynamic); + // saving checkpoint every 12 hours. + // To support 1 million entities in 36 data nodes, each node has roughly 28K models. + // In each hour, we roughly need to save 2400 models. Since each model saving can + // take about 1 seconds (default value of FORECAST_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS) + // we can use up to 2400 seconds to finish saving checkpoints. 
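+ // 2400 seconds of checkpoint writes per hour still fits within the 3600 seconds available in an hour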
+ public static final Setting FORECAST_CHECKPOINT_SAVING_FREQ = Setting + .positiveTimeSetting( + "plugins.forecast.checkpoint_saving_freq", + TimeValue.timeValueHours(12), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FORECAST_CHECKPOINT_TTL = Setting + .positiveTimeSetting( + "plugins.forecast.checkpoint_ttl", + TimeValue.timeValueDays(7), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + // ====================================== // Security // ====================================== - public static final Setting FILTER_BY_BACKEND_ROLES = Setting + public static final Setting FORECAST_FILTER_BY_BACKEND_ROLES = Setting .boolSetting("plugins.forecast.filter_by_backend_roles", false, Setting.Property.NodeScope, Setting.Property.Dynamic); // ====================================== @@ -217,6 +238,62 @@ public final class ForecastSettings { Setting.Property.Dynamic ); + // the percentage of heap usage allowed for queues holding large requests + // set it to 0 to disable the queue + public static final Setting FORECAST_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.forecast.checkpoint_write_queue_max_heap_percent", + 0.01f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FORECAST_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.forecast.checkpoint_maintain_queue_max_heap_percent", + 0.001f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FORECAST_COLD_START_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.forecast.cold_start_queue_max_heap_percent", + 0.001f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FORECAST_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.forecast.result_write_queue_max_heap_percent", + 0.01f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FORECAST_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.forecast.checkpoint_read_queue_max_heap_percent", + 0.001f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FORECAST_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.forecast.cold_entity_queue_max_heap_percent", + 0.001f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + // ====================================== // fault tolerance // ====================================== @@ -239,27 +316,28 @@ public final class ForecastSettings { Setting.Property.Dynamic ); - public static final Setting FORECAST_MAX_RETRY_FOR_UNRESPONSIVE_NODE = Setting - .intSetting("plugins.forecast.max_retry_for_unresponsive_node", 5, 0, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting FORECAST_MAX_RETRY_FOR_END_RUN_EXCEPTION = Setting + .intSetting("plugins.forecast.max_retry_for_end_run_exception", 6, 0, Setting.Property.NodeScope, Setting.Property.Dynamic); // ====================================== // cache related parameters // ====================================== /* * Opensearch-only setting - * Each detector has its dedicated cache that stores ten entities' states per node. - * A detector's hottest entities load their states into the dedicated cache. - * Other detectors cannot use space reserved by a detector's dedicated cache. 
+ * Each forecaster has its dedicated cache that stores ten entities' states per node for HC
+ * and one entity's state per node for single-stream forecasters.
+ * A forecaster's hottest entities load their states into the dedicated cache.
+ * Other forecasters cannot use space reserved by a forecaster's dedicated cache.
 * DEDICATED_CACHE_SIZE is a setting to make dedicated cache's size flexible.
 * When that setting is changed, if the size decreases, we will release memory
- * if required (e.g., when a user also decreased AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE,
- * the max memory percentage that AD can use);
+ * if required (e.g., when a user also decreased ForecastSettings.FORECAST_MODEL_MAX_SIZE_PERCENTAGE,
+ * the max memory percentage that the forecasting plugin can use);
 * if the size increases, we may reject the setting change if we cannot fulfill
- * that request (e.g., when it will uses more memory than allowed for AD).
+ * that request (e.g., when it would use more memory than allowed for Forecasting).
 *
 * With compact rcf, rcf with 30 trees and shingle size 4 is of 500KB.
 * The recommended max heap size is 32 GB. Even if users use all of the heap
- * for AD, the max number of entity model cannot surpass
+ * for Forecasting, the max number of entity models cannot surpass
 * 3.2 GB/500KB = 3.2 * 10^10 / 5*10^5 = 6.4 * 10 ^4
 * where 3.2 GB is from 10% memory limit of AD plugin.
 * That's why I am using 60_000 as the max limit.
@@ -268,6 +346,33 @@ public final class ForecastSettings {
 .intSetting("plugins.forecast.dedicated_cache_size", 10, 0, 60_000, Setting.Property.NodeScope, Setting.Property.Dynamic);

 public static final Setting FORECAST_MODEL_MAX_SIZE_PERCENTAGE = Setting
- .doubleSetting("plugins.forecast.model_max_size_percent", 0.1, 0, 0.7, Setting.Property.NodeScope, Setting.Property.Dynamic);
+ .doubleSetting("plugins.forecast.model_max_size_percent", 0.1, 0, 0.9, Setting.Property.NodeScope, Setting.Property.Dynamic);
+
+ // ======================================
+ // pagination setting
+ // ======================================
+ // pagination size
+ public static final Setting FORECAST_PAGE_SIZE = Setting
+ .intSetting("plugins.forecast.page_size", 1_000, 0, 10_000, Setting.Property.NodeScope, Setting.Property.Dynamic);
+
+ // Increasing the value adds pressure to indexing forecast results and our feature query
+ // OpenSearch-only setting, as the previous legacy default (1000) was too low
+ public static final Setting FORECAST_MAX_ENTITIES_PER_INTERVAL = Setting
+ .intSetting(
+ "plugins.forecast.max_entities_per_interval",
+ 1_000_000,
+ 0,
+ 2_000_000,
+ Setting.Property.NodeScope,
+ Setting.Property.Dynamic
+ );
+
+ // ======================================
+ // stats/profile API setting
+ // ======================================
+ // the max number of models to return per node.
+ // the setting is used to limit resource usage due to showing models + public static final Setting FORECAST_MAX_MODEL_SIZE_PER_NODE = Setting + .intSetting("plugins.forecast.max_model_size_per_node", 100, 1, 10_000, Setting.Property.NodeScope, Setting.Property.Dynamic); } diff --git a/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java-e b/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java-e new file mode 100644 index 000000000..a9d033b78 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/settings/ForecastSettings.java-e @@ -0,0 +1,378 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.settings; + +import org.opensearch.common.settings.Setting; +import org.opensearch.common.unit.TimeValue; + +public final class ForecastSettings { + // ====================================== + // config parameters + // ====================================== + public static final Setting FORECAST_INTERVAL = Setting + .positiveTimeSetting( + "plugins.forecast.default_interval", + TimeValue.timeValueMinutes(10), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FORECAST_WINDOW_DELAY = Setting + .timeSetting( + "plugins.forecast.default_window_delay", + TimeValue.timeValueMinutes(0), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // ====================================== + // restful apis + // ====================================== + public static final Setting FORECAST_REQUEST_TIMEOUT = Setting + .positiveTimeSetting( + "plugins.forecast.request_timeout", + TimeValue.timeValueSeconds(10), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // ====================================== + // resource constraint + // ====================================== + public static final Setting MAX_SINGLE_STREAM_FORECASTERS = Setting + .intSetting("plugins.forecast.max_forecasters", 1000, 0, 10_000, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting MAX_HC_FORECASTERS = Setting + .intSetting("plugins.forecast.max_hc_forecasters", 10, 0, 10_000, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // save partial zero-anomaly grade results after indexing pressure reaching the limit + // Opendistro version has similar setting. I lowered the value to make room + // for INDEX_PRESSURE_HARD_LIMIT. I don't find a floatSetting that has both default + // and fallback values. I want users to use the new default value 0.6 instead of 0.8. + // So do not plan to use the value of legacy setting as fallback. 
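For orientation, the next two settings are the soft and hard indexing-pressure limits (0.6 and 0.9 by default). Below is a minimal sketch of how a result-write path could gate on them, assuming the same tiered behavior the AD plugin applies with its own index pressure settings; the class and enum names are hypothetical and only for illustration:

// Hypothetical helper, not part of this patch.
enum SaveDecision { SAVE_ALL, SAVE_NON_ZERO_ONLY, SAVE_HIGH_PRIORITY_ONLY }

final class IndexPressureGate {
    static SaveDecision decide(float indexingPressure, float softLimit, float hardLimit) {
        if (indexingPressure <= softLimit) {
            return SaveDecision.SAVE_ALL;            // normal operation below the soft limit (0.6)
        } else if (indexingPressure <= hardLimit) {
            return SaveDecision.SAVE_NON_ZERO_ONLY;  // shed "nothing interesting" results between the limits
        }
        return SaveDecision.SAVE_HIGH_PRIORITY_ONLY; // above the hard limit (0.9), keep only errors and high-priority writes
    }
}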
+ public static final Setting FORECAST_INDEX_PRESSURE_SOFT_LIMIT = Setting + .floatSetting("plugins.forecast.index_pressure_soft_limit", 0.6f, 0.0f, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // save only error or larger-than-one anomaly grade results after indexing + // pressure reaching the limit + // opensearch-only setting + public static final Setting FORECAST_INDEX_PRESSURE_HARD_LIMIT = Setting + .floatSetting("plugins.forecast.index_pressure_hard_limit", 0.9f, 0.0f, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // we only allow single feature forecast now + public static final int MAX_FORECAST_FEATURES = 1; + + // ====================================== + // AD Index setting + // ====================================== + public static int FORECAST_MAX_UPDATE_RETRY_TIMES = 10_000; + + // ====================================== + // Indices + // ====================================== + public static final Setting FORECAST_RESULT_HISTORY_MAX_DOCS_PER_SHARD = Setting + .longSetting( + "plugins.forecast.forecast_result_history_max_docs_per_shard", + // Total documents in the primary shards. + // Note the count is for Lucene docs. Lucene considers a nested + // doc a doc too. One result on average equals to 4 Lucene docs. + // A single Lucene doc is roughly 46.8 bytes (measured by experiments). + // 1.35 billion docs is about 65 GB. One shard can have at most 65 GB. + // This number in Lucene doc count is used in RolloverRequest#addMaxIndexDocsCondition + // for adding condition to check if the index has at least numDocs. + 1_350_000_000L, + 0L, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FORECAST_RESULT_HISTORY_RETENTION_PERIOD = Setting + .positiveTimeSetting( + "plugins.forecast.forecast_result_history_retention_period", + TimeValue.timeValueDays(30), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FORECAST_RESULT_HISTORY_ROLLOVER_PERIOD = Setting + .positiveTimeSetting( + "plugins.forecast.forecast_result_history_rollover_period", + TimeValue.timeValueHours(12), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final String FORECAST_RESULTS_INDEX_MAPPING_FILE = "mappings/forecast-results.json"; + public static final String FORECAST_STATE_INDEX_MAPPING_FILE = "mappings/forecast-state.json"; + public static final String FORECAST_CHECKPOINT_INDEX_MAPPING_FILE = "mappings/forecast-checkpoint.json"; + + // max number of primary shards of a forecast index + public static final Setting FORECAST_MAX_PRIMARY_SHARDS = Setting + .intSetting("plugins.forecast.max_primary_shards", 20, 0, 200, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // saving checkpoint every 12 hours. + // To support 1 million entities in 36 data nodes, each node has roughly 28K models. + // In each hour, we roughly need to save 2400 models. Since each model saving can + // take about 1 seconds (default value of FORECAST_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS) + // we can use up to 2400 seconds to finish saving checkpoints. 
+ public static final Setting FORECAST_CHECKPOINT_SAVING_FREQ = Setting + .positiveTimeSetting( + "plugins.forecast.checkpoint_saving_freq", + TimeValue.timeValueHours(12), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FORECAST_CHECKPOINT_TTL = Setting + .positiveTimeSetting( + "plugins.forecast.checkpoint_ttl", + TimeValue.timeValueDays(7), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // ====================================== + // Security + // ====================================== + public static final Setting FORECAST_FILTER_BY_BACKEND_ROLES = Setting + .boolSetting("plugins.forecast.filter_by_backend_roles", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // ====================================== + // Task + // ====================================== + public static int MAX_OLD_FORECAST_TASK_DOCS = 1000; + + public static final Setting MAX_OLD_TASK_DOCS_PER_FORECASTER = Setting + .intSetting( + "plugins.forecast.max_old_task_docs_per_forecaster", + // One forecast task is roughly 1.5KB for normal case. Suppose task's size + // is 2KB conservatively. If we store 1000 forecast tasks for one forecaster, + // that will be 2GB. + 1, + 1, // keep at least 1 old task per forecaster + MAX_OLD_FORECAST_TASK_DOCS, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // Maximum number of deleted tasks can keep in cache. + public static final Setting MAX_CACHED_DELETED_TASKS = Setting + .intSetting("plugins.forecast.max_cached_deleted_tasks", 1000, 1, 10_000, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // ====================================== + // rate-limiting queue parameters + // ====================================== + /** + * ES recommends bulk size to be 5~15 MB. + * ref: https://tinyurl.com/3zdbmbwy + * Assume each checkpoint takes roughly 200KB. 25 requests are of 5 MB. + */ + public static final Setting FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE = Setting + .intSetting("plugins.forecast.checkpoint_write_queue_batch_size", 25, 1, 60, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // expected execution time per checkpoint maintain request. This setting controls + // the speed of checkpoint maintenance execution. The larger, the faster, and + // the more performance impact to customers' workload. + public static final Setting FORECAST_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS = Setting + .intSetting( + "plugins.forecast.expected_checkpoint_maintain_time_in_millisecs", + 1000, + 0, + 3600000, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + /** + * Max concurrent checkpoint writes per node + */ + public static final Setting FORECAST_CHECKPOINT_WRITE_QUEUE_CONCURRENCY = Setting + .intSetting("plugins.forecast.checkpoint_write_queue_concurrency", 2, 1, 10, Setting.Property.NodeScope, Setting.Property.Dynamic); + + /** + * Max concurrent cold starts per node + */ + public static final Setting FORECAST_COLD_START_QUEUE_CONCURRENCY = Setting + .intSetting("plugins.forecast.cold_start_queue_concurrency", 1, 1, 10, Setting.Property.NodeScope, Setting.Property.Dynamic); + + /** + * Max concurrent result writes per node. Since checkpoint is relatively large + * (250KB), we have 2 concurrent threads processing the queue. 
+ */ + public static final Setting FORECAST_RESULT_WRITE_QUEUE_CONCURRENCY = Setting + .intSetting("plugins.forecast.result_write_queue_concurrency", 2, 1, 10, Setting.Property.NodeScope, Setting.Property.Dynamic); + + /** + * ES recommends bulk size to be 5~15 MB. + * ref: https://tinyurl.com/3zdbmbwy + * Assume each result takes roughly 1KB. 5000 requests are of 5 MB. + */ + public static final Setting FORECAST_RESULT_WRITE_QUEUE_BATCH_SIZE = Setting + .intSetting("plugins.forecast.result_write_queue_batch_size", 5000, 1, 15000, Setting.Property.NodeScope, Setting.Property.Dynamic); + + /** + * Max concurrent checkpoint reads per node + */ + public static final Setting FORECAST_CHECKPOINT_READ_QUEUE_CONCURRENCY = Setting + .intSetting("plugins.forecast.checkpoint_read_queue_concurrency", 1, 1, 10, Setting.Property.NodeScope, Setting.Property.Dynamic); + + /** + * Assume each checkpoint takes roughly 200KB. 25 requests are of 5 MB. + */ + public static final Setting FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE = Setting + .intSetting("plugins.forecast.checkpoint_read_queue_batch_size", 25, 1, 60, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // expected execution time per cold entity request. This setting controls + // the speed of cold entity requests execution. The larger, the faster, and + // the more performance impact to customers' workload. + public static final Setting FORECAST_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS = Setting + .intSetting( + "plugins.forecast.expected_cold_entity_execution_time_in_millisecs", + 3000, + 0, + 3600000, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // the percentage of heap usage allowed for queues holding large requests + // set it to 0 to disable the queue + public static final Setting FORECAST_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.forecast.checkpoint_write_queue_max_heap_percent", + 0.01f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FORECAST_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.forecast.checkpoint_maintain_queue_max_heap_percent", + 0.001f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FORECAST_COLD_START_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.forecast.cold_start_queue_max_heap_percent", + 0.001f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FORECAST_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.forecast.result_write_queue_max_heap_percent", + 0.01f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FORECAST_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.forecast.checkpoint_read_queue_max_heap_percent", + 0.001f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FORECAST_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT = Setting + .floatSetting( + "plugins.forecast.cold_entity_queue_max_heap_percent", + 0.001f, + 0.0f, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // ====================================== + // fault tolerance + // ====================================== + public static final Setting FORECAST_BACKOFF_INITIAL_DELAY = Setting + .positiveTimeSetting( + "plugins.forecast.backoff_initial_delay", + TimeValue.timeValueMillis(1000), + Setting.Property.NodeScope, + 
Setting.Property.Dynamic + ); + + public static final Setting FORECAST_MAX_RETRY_FOR_BACKOFF = Setting + .intSetting("plugins.forecast.max_retry_for_backoff", 3, 0, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting FORECAST_BACKOFF_MINUTES = Setting + .positiveTimeSetting( + "plugins.forecast.backoff_minutes", + TimeValue.timeValueMinutes(15), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting FORECAST_MAX_RETRY_FOR_END_RUN_EXCEPTION = Setting + .intSetting("plugins.forecast.max_retry_for_end_run_exception", 6, 0, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // ====================================== + // cache related parameters + // ====================================== + /* + * Opensearch-only setting + * Each forecaster has its dedicated cache that stores ten entities' states per node for HC + * and one entity' state per node for single-stream forecaster. + * A forecaster's hottest entities load their states into the dedicated cache. + * Other forecasters cannot use space reserved by a forecaster's dedicated cache. + * DEDICATED_CACHE_SIZE is a setting to make dedicated cache's size flexible. + * When that setting is changed, if the size decreases, we will release memory + * if required (e.g., when a user also decreased ForecastSettings.FORECAST_MODEL_MAX_SIZE_PERCENTAGE, + * the max memory percentage that forecasting plugin can use); + * if the size increases, we may reject the setting change if we cannot fulfill + * that request (e.g., when it will uses more memory than allowed for Forecasting). + * + * With compact rcf, rcf with 30 trees and shingle size 4 is of 500KB. + * The recommended max heap size is 32 GB. Even if users use all of the heap + * for Forecasting, the max number of entity model cannot surpass + * 3.2 GB/500KB = 3.2 * 10^10 / 5*10^5 = 6.4 * 10 ^4 + * where 3.2 GB is from 10% memory limit of AD plugin. + * That's why I am using 60_000 as the max limit. + */ + public static final Setting FORECAST_DEDICATED_CACHE_SIZE = Setting + .intSetting("plugins.forecast.dedicated_cache_size", 10, 0, 60_000, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting FORECAST_MODEL_MAX_SIZE_PERCENTAGE = Setting + .doubleSetting("plugins.forecast.model_max_size_percent", 0.1, 0, 0.9, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // ====================================== + // pagination setting + // ====================================== + // pagination size + public static final Setting FORECAST_PAGE_SIZE = Setting + .intSetting("plugins.forecast.page_size", 1_000, 0, 10_000, Setting.Property.NodeScope, Setting.Property.Dynamic); + + // Increase the value will adding pressure to indexing anomaly results and our feature query + // OpenSearch-only setting as previous the legacy default is too low (1000) + public static final Setting FORECAST_MAX_ENTITIES_PER_INTERVAL = Setting + .intSetting( + "plugins.forecast.max_entities_per_interval", + 1_000_000, + 0, + 2_000_000, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // ====================================== + // stats/profile API setting + // ====================================== + // the max number of models to return per node. 
+ // the setting is used to limit resource usage due to showing models + public static final Setting FORECAST_MAX_MODEL_SIZE_PER_NODE = Setting + .intSetting("plugins.forecast.max_model_size_per_node", 100, 1, 10_000, Setting.Property.NodeScope, Setting.Property.Dynamic); + +} diff --git a/src/main/java/org/opensearch/timeseries/Name.java-e b/src/main/java/org/opensearch/timeseries/Name.java-e new file mode 100644 index 000000000..d53a2a33a --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/Name.java-e @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries; + +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; +import java.util.function.Function; + +/** + * A super type for enum types returning names + * + */ +public interface Name { + String getName(); + + static Set getNameFromCollection(Collection names, Function getName) { + Set res = new HashSet<>(); + for (String name : names) { + res.add(getName.apply(name)); + } + return res; + } +} diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java b/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java similarity index 97% rename from src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java rename to src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java index 774e79fa8..9d3e827eb 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java +++ b/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java @@ -9,7 +9,7 @@ * GitHub history for details. 
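Assuming the usual generic signature static <T extends Name> Set<T> getNameFromCollection(Collection<String> names, Function<String, T> getName), the Name interface introduced a little earlier is typically implemented by enums so that REST-layer strings can be mapped back to typed constants. A minimal usage sketch with a hypothetical enum:

import java.util.List;
import java.util.Set;

import org.opensearch.timeseries.Name;

// Hypothetical enum, only to illustrate Name.getNameFromCollection.
enum ExampleProfileName implements Name {
    STATE("state"),
    ERROR("error");

    private final String name;

    ExampleProfileName(String name) {
        this.name = name;
    }

    @Override
    public String getName() {
        return name;
    }

    static ExampleProfileName fromString(String name) {
        for (ExampleProfileName value : values()) {
            if (value.getName().equals(name)) {
                return value;
            }
        }
        throw new IllegalArgumentException("Unsupported profile name: " + name);
    }

    // e.g. getNames(List.of("state", "error")) -> {STATE, ERROR}
    static Set<ExampleProfileName> getNames(List<String> names) {
        return Name.getNameFromCollection(names, ExampleProfileName::fromString);
    }
}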
*/ -package org.opensearch.ad; +package org.opensearch.timeseries; import static java.util.Collections.unmodifiableList; @@ -34,6 +34,11 @@ import org.opensearch.SpecialPermission; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionResponse; +import org.opensearch.ad.AnomalyDetectorJobRunner; +import org.opensearch.ad.AnomalyDetectorRunner; +import org.opensearch.ad.ExecuteADResultResponseRecorder; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.caching.CacheProvider; import org.opensearch.ad.caching.EntityCache; @@ -161,7 +166,6 @@ import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.IndexScopedSettings; import org.opensearch.common.settings.Setting; @@ -169,9 +173,10 @@ import org.opensearch.common.settings.SettingsFilter; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; -import org.opensearch.common.xcontent.XContentParserUtils; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; import org.opensearch.forecast.model.Forecaster; @@ -216,10 +221,11 @@ /** * Entry point of AD plugin. */ -public class AnomalyDetectorPlugin extends Plugin implements ActionPlugin, ScriptPlugin, JobSchedulerExtension { +public class TimeSeriesAnalyticsPlugin extends Plugin implements ActionPlugin, ScriptPlugin, JobSchedulerExtension { - private static final Logger LOG = LogManager.getLogger(AnomalyDetectorPlugin.class); + private static final Logger LOG = LogManager.getLogger(TimeSeriesAnalyticsPlugin.class); + // AD constants public static final String LEGACY_AD_BASE = "/_opendistro/_anomaly_detection"; public static final String LEGACY_OPENDISTRO_AD_BASE_URI = LEGACY_AD_BASE + "/detectors"; public static final String AD_BASE_URI = "/_plugins/_anomaly_detection"; @@ -227,7 +233,16 @@ public class AnomalyDetectorPlugin extends Plugin implements ActionPlugin, Scrip public static final String AD_THREAD_POOL_PREFIX = "opensearch.ad."; public static final String AD_THREAD_POOL_NAME = "ad-threadpool"; public static final String AD_BATCH_TASK_THREAD_POOL_NAME = "ad-batch-task-threadpool"; - public static final String AD_JOB_TYPE = "opendistro_anomaly_detector"; + + // forecasting constants + public static final String FORECAST_BASE_URI = "/_plugins/_forecast"; + public static final String FORECAST_FORECASTERS_URI = FORECAST_BASE_URI + "/forecasters"; + public static final String FORECAST_THREAD_POOL_PREFIX = "opensearch.forecast."; + public static final String FORECAST_THREAD_POOL_NAME = "forecast-threadpool"; + public static final String FORECAST_BATCH_TASK_THREAD_POOL_NAME = "forecast-batch-task-threadpool"; + + public static final String TIME_SERIES_JOB_TYPE = "opensearch_time_series_analytics"; + private static Gson gson; private ADIndexManagement anomalyDetectionIndices; private AnomalyDetectorRunner anomalyDetectorRunner; @@ -250,10 +265,10 @@ public class 
AnomalyDetectorPlugin extends Plugin implements ActionPlugin, Scrip SpecialPermission.check(); // gson intialization requires "java.lang.RuntimePermission" "accessDeclaredMembers" to // initialize ConstructorConstructor - AccessController.doPrivileged((PrivilegedAction) AnomalyDetectorPlugin::initGson); + AccessController.doPrivileged((PrivilegedAction) TimeSeriesAnalyticsPlugin::initGson); } - public AnomalyDetectorPlugin() {} + public TimeSeriesAnalyticsPlugin() {} @Override public List getRestHandlers( @@ -1029,7 +1044,7 @@ public List getNamedXContent() { @Override public String getJobType() { - return AD_JOB_TYPE; + return TIME_SERIES_JOB_TYPE; } @Override diff --git a/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java-e b/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java-e new file mode 100644 index 000000000..9a7236460 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java-e @@ -0,0 +1,1083 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries; + +import static java.util.Collections.unmodifiableList; + +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.time.Clock; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.apache.commons.pool2.BasePooledObjectFactory; +import org.apache.commons.pool2.PooledObject; +import org.apache.commons.pool2.impl.DefaultPooledObject; +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.SpecialPermission; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionResponse; +import org.opensearch.ad.AnomalyDetectorJobRunner; +import org.opensearch.ad.AnomalyDetectorRunner; +import org.opensearch.ad.ExecuteADResultResponseRecorder; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.caching.PriorityCache; +import org.opensearch.ad.cluster.ADClusterEventListener; +import org.opensearch.ad.cluster.ADDataMigrator; +import org.opensearch.ad.cluster.ClusterManagerEventListener; +import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.CheckpointDao; +import org.opensearch.ad.ml.EntityColdStarter; +import org.opensearch.ad.ml.HybridThresholdingModel; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.model.DetectorInternalState; +import org.opensearch.ad.ratelimit.CheckPointMaintainRequestAdapter; +import 
org.opensearch.ad.ratelimit.CheckpointMaintainWorker; +import org.opensearch.ad.ratelimit.CheckpointReadWorker; +import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ColdEntityWorker; +import org.opensearch.ad.ratelimit.EntityColdStartWorker; +import org.opensearch.ad.ratelimit.ResultWriteWorker; +import org.opensearch.ad.rest.RestAnomalyDetectorJobAction; +import org.opensearch.ad.rest.RestDeleteAnomalyDetectorAction; +import org.opensearch.ad.rest.RestDeleteAnomalyResultsAction; +import org.opensearch.ad.rest.RestExecuteAnomalyDetectorAction; +import org.opensearch.ad.rest.RestGetAnomalyDetectorAction; +import org.opensearch.ad.rest.RestIndexAnomalyDetectorAction; +import org.opensearch.ad.rest.RestPreviewAnomalyDetectorAction; +import org.opensearch.ad.rest.RestSearchADTasksAction; +import org.opensearch.ad.rest.RestSearchAnomalyDetectorAction; +import org.opensearch.ad.rest.RestSearchAnomalyDetectorInfoAction; +import org.opensearch.ad.rest.RestSearchAnomalyResultAction; +import org.opensearch.ad.rest.RestSearchTopAnomalyResultAction; +import org.opensearch.ad.rest.RestStatsAnomalyDetectorAction; +import org.opensearch.ad.rest.RestValidateAnomalyDetectorAction; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.settings.ADNumericSetting; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.settings.LegacyOpenDistroAnomalyDetectorSettings; +import org.opensearch.ad.stats.ADStat; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.stats.suppliers.CounterSupplier; +import org.opensearch.ad.stats.suppliers.IndexStatusSupplier; +import org.opensearch.ad.stats.suppliers.ModelsOnNodeCountSupplier; +import org.opensearch.ad.stats.suppliers.ModelsOnNodeSupplier; +import org.opensearch.ad.stats.suppliers.SettableSupplier; +import org.opensearch.ad.task.ADBatchTaskRunner; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.ADBatchAnomalyResultAction; +import org.opensearch.ad.transport.ADBatchAnomalyResultTransportAction; +import org.opensearch.ad.transport.ADBatchTaskRemoteExecutionAction; +import org.opensearch.ad.transport.ADBatchTaskRemoteExecutionTransportAction; +import org.opensearch.ad.transport.ADCancelTaskAction; +import org.opensearch.ad.transport.ADCancelTaskTransportAction; +import org.opensearch.ad.transport.ADResultBulkAction; +import org.opensearch.ad.transport.ADResultBulkTransportAction; +import org.opensearch.ad.transport.ADStatsNodesAction; +import org.opensearch.ad.transport.ADStatsNodesTransportAction; +import org.opensearch.ad.transport.ADTaskProfileAction; +import org.opensearch.ad.transport.ADTaskProfileTransportAction; +import org.opensearch.ad.transport.AnomalyDetectorJobAction; +import org.opensearch.ad.transport.AnomalyDetectorJobTransportAction; +import org.opensearch.ad.transport.AnomalyResultAction; +import org.opensearch.ad.transport.AnomalyResultTransportAction; +import org.opensearch.ad.transport.CronAction; +import org.opensearch.ad.transport.CronTransportAction; +import org.opensearch.ad.transport.DeleteAnomalyDetectorAction; +import org.opensearch.ad.transport.DeleteAnomalyDetectorTransportAction; +import org.opensearch.ad.transport.DeleteAnomalyResultsAction; +import org.opensearch.ad.transport.DeleteAnomalyResultsTransportAction; +import org.opensearch.ad.transport.DeleteModelAction; +import org.opensearch.ad.transport.DeleteModelTransportAction; +import 
org.opensearch.ad.transport.EntityProfileAction; +import org.opensearch.ad.transport.EntityProfileTransportAction; +import org.opensearch.ad.transport.EntityResultAction; +import org.opensearch.ad.transport.EntityResultTransportAction; +import org.opensearch.ad.transport.ForwardADTaskAction; +import org.opensearch.ad.transport.ForwardADTaskTransportAction; +import org.opensearch.ad.transport.GetAnomalyDetectorAction; +import org.opensearch.ad.transport.GetAnomalyDetectorTransportAction; +import org.opensearch.ad.transport.IndexAnomalyDetectorAction; +import org.opensearch.ad.transport.IndexAnomalyDetectorTransportAction; +import org.opensearch.ad.transport.PreviewAnomalyDetectorAction; +import org.opensearch.ad.transport.PreviewAnomalyDetectorTransportAction; +import org.opensearch.ad.transport.ProfileAction; +import org.opensearch.ad.transport.ProfileTransportAction; +import org.opensearch.ad.transport.RCFPollingAction; +import org.opensearch.ad.transport.RCFPollingTransportAction; +import org.opensearch.ad.transport.RCFResultAction; +import org.opensearch.ad.transport.RCFResultTransportAction; +import org.opensearch.ad.transport.SearchADTasksAction; +import org.opensearch.ad.transport.SearchADTasksTransportAction; +import org.opensearch.ad.transport.SearchAnomalyDetectorAction; +import org.opensearch.ad.transport.SearchAnomalyDetectorInfoAction; +import org.opensearch.ad.transport.SearchAnomalyDetectorInfoTransportAction; +import org.opensearch.ad.transport.SearchAnomalyDetectorTransportAction; +import org.opensearch.ad.transport.SearchAnomalyResultAction; +import org.opensearch.ad.transport.SearchAnomalyResultTransportAction; +import org.opensearch.ad.transport.SearchTopAnomalyResultAction; +import org.opensearch.ad.transport.SearchTopAnomalyResultTransportAction; +import org.opensearch.ad.transport.StatsAnomalyDetectorAction; +import org.opensearch.ad.transport.StatsAnomalyDetectorTransportAction; +import org.opensearch.ad.transport.StopDetectorAction; +import org.opensearch.ad.transport.StopDetectorTransportAction; +import org.opensearch.ad.transport.ThresholdResultAction; +import org.opensearch.ad.transport.ThresholdResultTransportAction; +import org.opensearch.ad.transport.ValidateAnomalyDetectorAction; +import org.opensearch.ad.transport.ValidateAnomalyDetectorTransportAction; +import org.opensearch.ad.transport.handler.ADSearchHandler; +import org.opensearch.ad.transport.handler.AnomalyIndexHandler; +import org.opensearch.ad.transport.handler.AnomalyResultBulkIndexHandler; +import org.opensearch.ad.transport.handler.MultiEntityResultHandler; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.IndexUtils; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.ad.util.Throttler; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.IndexScopedSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.settings.SettingsFilter; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.core.xcontent.XContentParserUtils; +import 
org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.env.Environment; +import org.opensearch.env.NodeEnvironment; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.jobscheduler.spi.JobSchedulerExtension; +import org.opensearch.jobscheduler.spi.ScheduledJobParser; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.monitor.jvm.JvmInfo; +import org.opensearch.monitor.jvm.JvmService; +import org.opensearch.plugins.ActionPlugin; +import org.opensearch.plugins.Plugin; +import org.opensearch.plugins.ScriptPlugin; +import org.opensearch.repositories.RepositoriesService; +import org.opensearch.rest.RestController; +import org.opensearch.rest.RestHandler; +import org.opensearch.script.ScriptService; +import org.opensearch.threadpool.ExecutorBuilder; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.dataprocessor.Imputer; +import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.function.ThrowingSupplierWrapper; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.watcher.ResourceWatcherService; + +import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper; +import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState; +import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV3StateConverter; +import com.amazon.randomcutforest.state.RandomCutForestMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; + +import io.protostuff.LinkedBuffer; +import io.protostuff.Schema; +import io.protostuff.runtime.RuntimeSchema; + +/** + * Entry point of AD plugin. 
+ */ +public class TimeSeriesAnalyticsPlugin extends Plugin implements ActionPlugin, ScriptPlugin, JobSchedulerExtension { + + private static final Logger LOG = LogManager.getLogger(TimeSeriesAnalyticsPlugin.class); + + // AD constants + public static final String LEGACY_AD_BASE = "/_opendistro/_anomaly_detection"; + public static final String LEGACY_OPENDISTRO_AD_BASE_URI = LEGACY_AD_BASE + "/detectors"; + public static final String AD_BASE_URI = "/_plugins/_anomaly_detection"; + public static final String AD_BASE_DETECTORS_URI = AD_BASE_URI + "/detectors"; + public static final String AD_THREAD_POOL_PREFIX = "opensearch.ad."; + public static final String AD_THREAD_POOL_NAME = "ad-threadpool"; + public static final String AD_BATCH_TASK_THREAD_POOL_NAME = "ad-batch-task-threadpool"; + + // forecasting constants + public static final String FORECAST_BASE_URI = "/_plugins/_forecast"; + public static final String FORECAST_FORECASTERS_URI = FORECAST_BASE_URI + "/forecasters"; + public static final String FORECAST_THREAD_POOL_PREFIX = "opensearch.forecast."; + public static final String FORECAST_THREAD_POOL_NAME = "forecast-threadpool"; + public static final String FORECAST_BATCH_TASK_THREAD_POOL_NAME = "forecast-batch-task-threadpool"; + + public static final String TIME_SERIES_JOB_TYPE = "opensearch_time_series_analytics"; + + private static Gson gson; + private ADIndexManagement anomalyDetectionIndices; + private AnomalyDetectorRunner anomalyDetectorRunner; + private Client client; + private ClusterService clusterService; + private ThreadPool threadPool; + private ADStats adStats; + private ClientUtil clientUtil; + private SecurityClientUtil securityClientUtil; + private DiscoveryNodeFilterer nodeFilter; + private IndexUtils indexUtils; + private ADTaskManager adTaskManager; + private ADBatchTaskRunner adBatchTaskRunner; + // package private for testing + GenericObjectPool serializeRCFBufferPool; + private NodeStateManager stateManager; + private ExecuteADResultResponseRecorder adResultResponseRecorder; + + static { + SpecialPermission.check(); + // gson intialization requires "java.lang.RuntimePermission" "accessDeclaredMembers" to + // initialize ConstructorConstructor + AccessController.doPrivileged((PrivilegedAction) TimeSeriesAnalyticsPlugin::initGson); + } + + public TimeSeriesAnalyticsPlugin() {} + + @Override + public List getRestHandlers( + Settings settings, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster + ) { + AnomalyDetectorJobRunner jobRunner = AnomalyDetectorJobRunner.getJobRunnerInstance(); + jobRunner.setClient(client); + jobRunner.setThreadPool(threadPool); + jobRunner.setSettings(settings); + jobRunner.setAnomalyDetectionIndices(anomalyDetectionIndices); + jobRunner.setAdTaskManager(adTaskManager); + jobRunner.setNodeStateManager(stateManager); + jobRunner.setExecuteADResultResponseRecorder(adResultResponseRecorder); + + RestGetAnomalyDetectorAction restGetAnomalyDetectorAction = new RestGetAnomalyDetectorAction(); + RestIndexAnomalyDetectorAction restIndexAnomalyDetectorAction = new RestIndexAnomalyDetectorAction(settings, clusterService); + RestSearchAnomalyDetectorAction searchAnomalyDetectorAction = new RestSearchAnomalyDetectorAction(); + RestSearchAnomalyResultAction searchAnomalyResultAction = new RestSearchAnomalyResultAction(); + RestSearchADTasksAction searchADTasksAction = new 
RestSearchADTasksAction(); + RestDeleteAnomalyDetectorAction deleteAnomalyDetectorAction = new RestDeleteAnomalyDetectorAction(); + RestExecuteAnomalyDetectorAction executeAnomalyDetectorAction = new RestExecuteAnomalyDetectorAction(settings, clusterService); + RestStatsAnomalyDetectorAction statsAnomalyDetectorAction = new RestStatsAnomalyDetectorAction(adStats, this.nodeFilter); + RestAnomalyDetectorJobAction anomalyDetectorJobAction = new RestAnomalyDetectorJobAction(settings, clusterService); + RestSearchAnomalyDetectorInfoAction searchAnomalyDetectorInfoAction = new RestSearchAnomalyDetectorInfoAction(); + RestPreviewAnomalyDetectorAction previewAnomalyDetectorAction = new RestPreviewAnomalyDetectorAction(); + RestDeleteAnomalyResultsAction deleteAnomalyResultsAction = new RestDeleteAnomalyResultsAction(); + RestSearchTopAnomalyResultAction searchTopAnomalyResultAction = new RestSearchTopAnomalyResultAction(); + RestValidateAnomalyDetectorAction validateAnomalyDetectorAction = new RestValidateAnomalyDetectorAction(settings, clusterService); + + return ImmutableList + .of( + restGetAnomalyDetectorAction, + restIndexAnomalyDetectorAction, + searchAnomalyDetectorAction, + searchAnomalyResultAction, + searchADTasksAction, + deleteAnomalyDetectorAction, + executeAnomalyDetectorAction, + anomalyDetectorJobAction, + statsAnomalyDetectorAction, + searchAnomalyDetectorInfoAction, + previewAnomalyDetectorAction, + deleteAnomalyResultsAction, + searchTopAnomalyResultAction, + validateAnomalyDetectorAction + ); + } + + private static Void initGson() { + gson = new GsonBuilder().serializeSpecialFloatingPointValues().create(); + return null; + } + + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier + ) { + ADEnabledSetting.getInstance().init(clusterService); + ADNumericSetting.getInstance().init(clusterService); + this.client = client; + this.threadPool = threadPool; + Settings settings = environment.settings(); + Throttler throttler = new Throttler(getClock()); + this.clientUtil = new ClientUtil(settings, client, throttler, threadPool); + this.indexUtils = new IndexUtils(client, clientUtil, clusterService, indexNameExpressionResolver); + this.nodeFilter = new DiscoveryNodeFilterer(clusterService); + // convert from checked IOException to unchecked RuntimeException + this.anomalyDetectionIndices = ThrowingSupplierWrapper + .throwingSupplierWrapper( + () -> new ADIndexManagement( + client, + clusterService, + threadPool, + settings, + nodeFilter, + TimeSeriesSettings.MAX_UPDATE_RETRY_TIMES + ) + ) + .get(); + this.clusterService = clusterService; + + Imputer imputer = new LinearUniformImputer(true); + stateManager = new NodeStateManager( + client, + xContentRegistry, + settings, + clientUtil, + getClock(), + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + clusterService + ); + securityClientUtil = new SecurityClientUtil(stateManager, settings); + SearchFeatureDao searchFeatureDao = new SearchFeatureDao( + client, + xContentRegistry, + imputer, + securityClientUtil, + settings, + clusterService, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE + ); + + JvmService jvmService = new 
JvmService(environment.settings()); + RandomCutForestMapper mapper = new RandomCutForestMapper(); + mapper.setSaveExecutorContextEnabled(true); + mapper.setSaveTreeStateEnabled(true); + mapper.setPartialTreeStateEnabled(true); + V1JsonToV3StateConverter converter = new V1JsonToV3StateConverter(); + + double modelMaxSizePercent = AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(settings); + + ADCircuitBreakerService adCircuitBreakerService = new ADCircuitBreakerService(jvmService).init(); + + MemoryTracker memoryTracker = new MemoryTracker( + jvmService, + modelMaxSizePercent, + AnomalyDetectorSettings.DESIRED_MODEL_SIZE_PERCENTAGE, + clusterService, + adCircuitBreakerService + ); + + FeatureManager featureManager = new FeatureManager( + searchFeatureDao, + imputer, + getClock(), + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, + AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, + AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, + AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + threadPool, + AD_THREAD_POOL_NAME + ); + + long heapSizeBytes = JvmInfo.jvmInfo().getMem().getHeapMax().getBytes(); + + serializeRCFBufferPool = AccessController.doPrivileged(new PrivilegedAction>() { + @Override + public GenericObjectPool run() { + return new GenericObjectPool<>(new BasePooledObjectFactory() { + @Override + public LinkedBuffer create() throws Exception { + return LinkedBuffer.allocate(AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES); + } + + @Override + public PooledObject wrap(LinkedBuffer obj) { + return new DefaultPooledObject<>(obj); + } + }); + } + }); + serializeRCFBufferPool.setMaxTotal(AnomalyDetectorSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); + serializeRCFBufferPool.setMaxIdle(AnomalyDetectorSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); + serializeRCFBufferPool.setMinIdle(0); + serializeRCFBufferPool.setBlockWhenExhausted(false); + serializeRCFBufferPool.setTimeBetweenEvictionRuns(AnomalyDetectorSettings.HOURLY_MAINTENANCE); + + CheckpointDao checkpoint = new CheckpointDao( + client, + clientUtil, + ADCommonName.CHECKPOINT_INDEX_NAME, + gson, + mapper, + converter, + new ThresholdedRandomCutForestMapper(), + AccessController + .doPrivileged( + (PrivilegedAction>) () -> RuntimeSchema + .getSchema(ThresholdedRandomCutForestState.class) + ), + HybridThresholdingModel.class, + anomalyDetectionIndices, + AnomalyDetectorSettings.MAX_CHECKPOINT_BYTES, + serializeRCFBufferPool, + AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES, + 1 - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE + ); + + Random random = new Random(42); + + CacheProvider cacheProvider = new CacheProvider(); + + CheckPointMaintainRequestAdapter adapter = new CheckPointMaintainRequestAdapter( + cacheProvider, + checkpoint, + ADCommonName.CHECKPOINT_INDEX_NAME, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, + getClock(), + clusterService, + settings + ); + + CheckpointWriteWorker checkpointWriteQueue = new CheckpointWriteWorker( + heapSizeBytes, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + 
AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + checkpoint, + ADCommonName.CHECKPOINT_INDEX_NAME, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + stateManager, + AnomalyDetectorSettings.HOURLY_MAINTENANCE + ); + + CheckpointMaintainWorker checkpointMaintainQueue = new CheckpointMaintainWorker( + heapSizeBytes, + AnomalyDetectorSettings.CHECKPOINT_MAINTAIN_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + checkpointWriteQueue, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + stateManager, + adapter + ); + + EntityCache cache = new PriorityCache( + checkpoint, + AnomalyDetectorSettings.DEDICATED_CACHE_SIZE.get(settings), + AnomalyDetectorSettings.CHECKPOINT_TTL, + AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES, + memoryTracker, + AnomalyDetectorSettings.NUM_TREES, + getClock(), + clusterService, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + threadPool, + checkpointWriteQueue, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + checkpointMaintainQueue, + settings, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ + ); + + cacheProvider.set(cache); + + EntityColdStarter entityColdStarter = new EntityColdStarter( + getClock(), + threadPool, + stateManager, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.TIME_DECAY, + AnomalyDetectorSettings.NUM_MIN_SAMPLES, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + imputer, + searchFeatureDao, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + featureManager, + settings, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + ); + + EntityColdStartWorker coldstartQueue = new EntityColdStartWorker( + heapSizeBytes, + AnomalyDetectorSettings.ENTITY_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + entityColdStarter, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + stateManager, + cacheProvider + ); + + ModelManager modelManager = new ModelManager( + checkpoint, + getClock(), + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.TIME_DECAY, + AnomalyDetectorSettings.NUM_MIN_SAMPLES, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.MIN_PREVIEW_SIZE, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, + entityColdStarter, + featureManager, + memoryTracker, + settings, + clusterService + ); + + MultiEntityResultHandler multiEntityResultHandler = new MultiEntityResultHandler( + client, + settings, + threadPool, + anomalyDetectionIndices, + this.clientUtil, + this.indexUtils, + clusterService + ); + + ResultWriteWorker resultWriteQueue = new 
ResultWriteWorker( + heapSizeBytes, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_SIZE_IN_BYTES, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + multiEntityResultHandler, + xContentRegistry, + stateManager, + AnomalyDetectorSettings.HOURLY_MAINTENANCE + ); + + Map> stats = ImmutableMap + .>builder() + .put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) + .put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) + .put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) + .put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) + .put( + StatNames.MODEL_INFORMATION.getName(), + new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService)) + ) + .put( + StatNames.ANOMALY_DETECTORS_INDEX_STATUS.getName(), + new ADStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.CONFIG_INDEX)) + ) + .put( + StatNames.ANOMALY_RESULTS_INDEX_STATUS.getName(), + new ADStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)) + ) + .put( + StatNames.MODELS_CHECKPOINT_INDEX_STATUS.getName(), + new ADStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.CHECKPOINT_INDEX_NAME)) + ) + .put( + StatNames.ANOMALY_DETECTION_JOB_INDEX_STATUS.getName(), + new ADStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.JOB_INDEX)) + ) + .put( + StatNames.ANOMALY_DETECTION_STATE_STATUS.getName(), + new ADStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.DETECTION_STATE_INDEX)) + ) + .put(StatNames.DETECTOR_COUNT.getName(), new ADStat<>(true, new SettableSupplier())) + .put(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName(), new ADStat<>(true, new SettableSupplier())) + .put(StatNames.MULTI_ENTITY_DETECTOR_COUNT.getName(), new ADStat<>(true, new SettableSupplier())) + .put(StatNames.AD_EXECUTING_BATCH_TASK_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) + .put(StatNames.AD_CANCELED_BATCH_TASK_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) + .put(StatNames.AD_TOTAL_BATCH_TASK_EXECUTION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) + .put(StatNames.AD_BATCH_TASK_FAILURE_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) + .put(StatNames.MODEL_COUNT.getName(), new ADStat<>(false, new ModelsOnNodeCountSupplier(modelManager, cacheProvider))) + .put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) + .build(); + + adStats = new ADStats(stats); + + CheckpointReadWorker checkpointReadQueue = new CheckpointReadWorker( + heapSizeBytes, + AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + modelManager, + checkpoint, + 
coldstartQueue, + resultWriteQueue, + stateManager, + anomalyDetectionIndices, + cacheProvider, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + adStats + ); + + ColdEntityWorker coldEntityQueue = new ColdEntityWorker( + heapSizeBytes, + AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + checkpointReadQueue, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + stateManager + ); + + ADDataMigrator dataMigrator = new ADDataMigrator(client, clusterService, xContentRegistry, anomalyDetectionIndices); + HashRing hashRing = new HashRing(nodeFilter, getClock(), settings, client, clusterService, dataMigrator, modelManager); + + anomalyDetectorRunner = new AnomalyDetectorRunner(modelManager, featureManager, AnomalyDetectorSettings.MAX_PREVIEW_RESULTS); + + ADTaskCacheManager adTaskCacheManager = new ADTaskCacheManager(settings, clusterService, memoryTracker); + adTaskManager = new ADTaskManager( + settings, + clusterService, + client, + xContentRegistry, + anomalyDetectionIndices, + nodeFilter, + hashRing, + adTaskCacheManager, + threadPool + ); + AnomalyResultBulkIndexHandler anomalyResultBulkIndexHandler = new AnomalyResultBulkIndexHandler( + client, + settings, + threadPool, + this.clientUtil, + this.indexUtils, + clusterService, + anomalyDetectionIndices + ); + adBatchTaskRunner = new ADBatchTaskRunner( + settings, + threadPool, + clusterService, + client, + securityClientUtil, + adCircuitBreakerService, + featureManager, + adTaskManager, + anomalyDetectionIndices, + adStats, + anomalyResultBulkIndexHandler, + adTaskCacheManager, + searchFeatureDao, + hashRing, + modelManager + ); + + ADSearchHandler adSearchHandler = new ADSearchHandler(settings, clusterService, client); + + AnomalyIndexHandler anomalyResultHandler = new AnomalyIndexHandler( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + this.clientUtil, + this.indexUtils, + clusterService + ); + + adResultResponseRecorder = new ExecuteADResultResponseRecorder( + anomalyDetectionIndices, + anomalyResultHandler, + adTaskManager, + nodeFilter, + threadPool, + client, + stateManager, + adTaskCacheManager, + AnomalyDetectorSettings.NUM_MIN_SAMPLES + ); + + // return objects used by Guice to inject dependencies for e.g., + // transport action handler constructors + return ImmutableList + .of( + anomalyDetectionIndices, + anomalyDetectorRunner, + searchFeatureDao, + imputer, + gson, + jvmService, + hashRing, + featureManager, + modelManager, + stateManager, + new ADClusterEventListener(clusterService, hashRing), + adCircuitBreakerService, + adStats, + new ClusterManagerEventListener( + clusterService, + threadPool, + client, + getClock(), + clientUtil, + nodeFilter, + AnomalyDetectorSettings.CHECKPOINT_TTL, + settings + ), + nodeFilter, + multiEntityResultHandler, + checkpoint, + cacheProvider, + adTaskManager, + adBatchTaskRunner, + adSearchHandler, + coldstartQueue, + resultWriteQueue, + checkpointReadQueue, + checkpointWriteQueue, + coldEntityQueue, + entityColdStarter, + adTaskCacheManager, + adResultResponseRecorder + ); + } + + /** + * createComponents doesn't work for Clock as ES process cannot 
start + * complaining it cannot find Clock instances for transport actions constructors. + * @return a UTC clock + */ + protected Clock getClock() { + return Clock.systemUTC(); + } + + @Override + public List> getExecutorBuilders(Settings settings) { + return ImmutableList + .of( + new ScalingExecutorBuilder( + AD_THREAD_POOL_NAME, + 1, + // HCAD can be heavy after supporting 1 million entities. + // Limit to use at most half of the processors. + Math.max(1, OpenSearchExecutors.allocatedProcessors(settings) / 2), + TimeValue.timeValueMinutes(10), + AD_THREAD_POOL_PREFIX + AD_THREAD_POOL_NAME + ), + new ScalingExecutorBuilder( + AD_BATCH_TASK_THREAD_POOL_NAME, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(settings) / 8), + TimeValue.timeValueMinutes(10), + AD_THREAD_POOL_PREFIX + AD_BATCH_TASK_THREAD_POOL_NAME + ) + ); + } + + @Override + public List> getSettings() { + List> enabledSetting = ADEnabledSetting.getInstance().getSettings(); + List> numericSetting = ADNumericSetting.getInstance().getSettings(); + + List> systemSetting = ImmutableList + .of( + // ====================================== + // AD settings + // ====================================== + // HCAD cache + LegacyOpenDistroAnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND, + AnomalyDetectorSettings.DEDICATED_CACHE_SIZE, + // Detector config + LegacyOpenDistroAnomalyDetectorSettings.DETECTION_INTERVAL, + LegacyOpenDistroAnomalyDetectorSettings.DETECTION_WINDOW_DELAY, + LegacyOpenDistroAnomalyDetectorSettings.MAX_ANOMALY_FEATURES, + AnomalyDetectorSettings.DETECTION_INTERVAL, + AnomalyDetectorSettings.DETECTION_WINDOW_DELAY, + AnomalyDetectorSettings.MAX_ANOMALY_FEATURES, + // Fault tolerance + LegacyOpenDistroAnomalyDetectorSettings.REQUEST_TIMEOUT, + LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE, + LegacyOpenDistroAnomalyDetectorSettings.COOLDOWN_MINUTES, + LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_MINUTES, + LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_INITIAL_DELAY, + LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_BACKOFF, + AnomalyDetectorSettings.REQUEST_TIMEOUT, + AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE, + AnomalyDetectorSettings.COOLDOWN_MINUTES, + AnomalyDetectorSettings.BACKOFF_MINUTES, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF, + // result index rollover + LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, + LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS, + LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD, + AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, + AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD, + AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD, + // resource usage control + LegacyOpenDistroAnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, + LegacyOpenDistroAnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, + LegacyOpenDistroAnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS, + LegacyOpenDistroAnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT, + LegacyOpenDistroAnomalyDetectorSettings.MAX_PRIMARY_SHARDS, + AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, + AnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, + AnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS, + AnomalyDetectorSettings.AD_INDEX_PRESSURE_SOFT_LIMIT, + AnomalyDetectorSettings.AD_INDEX_PRESSURE_HARD_LIMIT, + AnomalyDetectorSettings.AD_MAX_PRIMARY_SHARDS, + // 
Security + LegacyOpenDistroAnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES, + AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES, + // Historical + LegacyOpenDistroAnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE, + LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS, + LegacyOpenDistroAnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, + LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE, + AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE, + AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS, + AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, + AnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE, + AnomalyDetectorSettings.MAX_TOP_ENTITIES_FOR_HISTORICAL_ANALYSIS, + AnomalyDetectorSettings.MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS, + AnomalyDetectorSettings.MAX_CACHED_DELETED_TASKS, + // rate limiting + AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_BATCH_SIZE, + AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, + AnomalyDetectorSettings.AD_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, + AnomalyDetectorSettings.CHECKPOINT_TTL, + // query limit + LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, + LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW, + AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, + AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW, + AnomalyDetectorSettings.MAX_CONCURRENT_PREVIEW, + AnomalyDetectorSettings.PAGE_SIZE, + // clean resource + AnomalyDetectorSettings.DELETE_AD_RESULT_WHEN_DELETE_DETECTOR, + // stats/profile API + AnomalyDetectorSettings.MAX_MODEL_SIZE_PER_NODE, + // ====================================== + // Forecast settings + // ====================================== + // result index rollover + ForecastSettings.FORECAST_RESULT_HISTORY_MAX_DOCS_PER_SHARD, + ForecastSettings.FORECAST_RESULT_HISTORY_RETENTION_PERIOD, + ForecastSettings.FORECAST_RESULT_HISTORY_ROLLOVER_PERIOD, + // resource usage control + ForecastSettings.FORECAST_MODEL_MAX_SIZE_PERCENTAGE, + // TODO: add validation code + // ForecastSettings.FORECAST_MAX_SINGLE_STREAM_FORECASTERS, + // ForecastSettings.FORECAST_MAX_HC_FORECASTERS, + ForecastSettings.FORECAST_INDEX_PRESSURE_SOFT_LIMIT, + ForecastSettings.FORECAST_INDEX_PRESSURE_HARD_LIMIT, + ForecastSettings.FORECAST_MAX_PRIMARY_SHARDS + ); + return unmodifiableList( + Stream + .of(enabledSetting.stream(), systemSetting.stream(), numericSetting.stream()) + .reduce(Stream::concat) + .orElseGet(Stream::empty) + .collect(Collectors.toList()) + ); + } + + @Override + public List getNamedXContent() { + return ImmutableList + .of( + AnomalyDetector.XCONTENT_REGISTRY, + AnomalyResult.XCONTENT_REGISTRY, + 
DetectorInternalState.XCONTENT_REGISTRY, + AnomalyDetectorJob.XCONTENT_REGISTRY, + Forecaster.XCONTENT_REGISTRY + ); + } + + /* + * Register action and handler so that transportClient can find proxy for action + */ + @Override + public List> getActions() { + return Arrays + .asList( + new ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class), + new ActionHandler<>(StopDetectorAction.INSTANCE, StopDetectorTransportAction.class), + new ActionHandler<>(RCFResultAction.INSTANCE, RCFResultTransportAction.class), + new ActionHandler<>(ThresholdResultAction.INSTANCE, ThresholdResultTransportAction.class), + new ActionHandler<>(AnomalyResultAction.INSTANCE, AnomalyResultTransportAction.class), + new ActionHandler<>(CronAction.INSTANCE, CronTransportAction.class), + new ActionHandler<>(ADStatsNodesAction.INSTANCE, ADStatsNodesTransportAction.class), + new ActionHandler<>(ProfileAction.INSTANCE, ProfileTransportAction.class), + new ActionHandler<>(RCFPollingAction.INSTANCE, RCFPollingTransportAction.class), + new ActionHandler<>(SearchAnomalyDetectorAction.INSTANCE, SearchAnomalyDetectorTransportAction.class), + new ActionHandler<>(SearchAnomalyResultAction.INSTANCE, SearchAnomalyResultTransportAction.class), + new ActionHandler<>(SearchADTasksAction.INSTANCE, SearchADTasksTransportAction.class), + new ActionHandler<>(StatsAnomalyDetectorAction.INSTANCE, StatsAnomalyDetectorTransportAction.class), + new ActionHandler<>(DeleteAnomalyDetectorAction.INSTANCE, DeleteAnomalyDetectorTransportAction.class), + new ActionHandler<>(GetAnomalyDetectorAction.INSTANCE, GetAnomalyDetectorTransportAction.class), + new ActionHandler<>(IndexAnomalyDetectorAction.INSTANCE, IndexAnomalyDetectorTransportAction.class), + new ActionHandler<>(AnomalyDetectorJobAction.INSTANCE, AnomalyDetectorJobTransportAction.class), + new ActionHandler<>(ADResultBulkAction.INSTANCE, ADResultBulkTransportAction.class), + new ActionHandler<>(EntityResultAction.INSTANCE, EntityResultTransportAction.class), + new ActionHandler<>(EntityProfileAction.INSTANCE, EntityProfileTransportAction.class), + new ActionHandler<>(SearchAnomalyDetectorInfoAction.INSTANCE, SearchAnomalyDetectorInfoTransportAction.class), + new ActionHandler<>(PreviewAnomalyDetectorAction.INSTANCE, PreviewAnomalyDetectorTransportAction.class), + new ActionHandler<>(ADBatchAnomalyResultAction.INSTANCE, ADBatchAnomalyResultTransportAction.class), + new ActionHandler<>(ADBatchTaskRemoteExecutionAction.INSTANCE, ADBatchTaskRemoteExecutionTransportAction.class), + new ActionHandler<>(ADTaskProfileAction.INSTANCE, ADTaskProfileTransportAction.class), + new ActionHandler<>(ADCancelTaskAction.INSTANCE, ADCancelTaskTransportAction.class), + new ActionHandler<>(ForwardADTaskAction.INSTANCE, ForwardADTaskTransportAction.class), + new ActionHandler<>(DeleteAnomalyResultsAction.INSTANCE, DeleteAnomalyResultsTransportAction.class), + new ActionHandler<>(SearchTopAnomalyResultAction.INSTANCE, SearchTopAnomalyResultTransportAction.class), + new ActionHandler<>(ValidateAnomalyDetectorAction.INSTANCE, ValidateAnomalyDetectorTransportAction.class) + ); + } + + @Override + public String getJobType() { + return TIME_SERIES_JOB_TYPE; + } + + @Override + public String getJobIndex() { + return CommonName.JOB_INDEX; + } + + @Override + public ScheduledJobRunner getJobRunner() { + return AnomalyDetectorJobRunner.getJobRunnerInstance(); + } + + @Override + public ScheduledJobParser getJobParser() { + return (parser, id, jobDocVersion) -> { + 
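+ // advance to the job document's START_OBJECT before delegating to AnomalyDetectorJob.parse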
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + return AnomalyDetectorJob.parse(parser); + }; + } + + @Override + public void close() { + if (serializeRCFBufferPool != null) { + try { + AccessController.doPrivileged((PrivilegedAction) () -> { + serializeRCFBufferPool.clear(); + serializeRCFBufferPool.close(); + return null; + }); + serializeRCFBufferPool = null; + } catch (Exception e) { + LOG.error("Failed to shut down object Pool", e); + } + } + } +} diff --git a/src/main/java/org/opensearch/timeseries/annotation/Generated.java-e b/src/main/java/org/opensearch/timeseries/annotation/Generated.java-e new file mode 100644 index 000000000..d3812c9ca --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/annotation/Generated.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Jacoco will ignore the annotated code. + * Similar to Lombok Generated annotation. Create this similar annotation as we don't involve Lombok. + */ +@Target({ ElementType.CONSTRUCTOR, ElementType.METHOD, ElementType.FIELD, ElementType.TYPE }) +@Retention(RetentionPolicy.CLASS) +public @interface Generated { +} diff --git a/src/main/java/org/opensearch/timeseries/common/exception/ClientException.java-e b/src/main/java/org/opensearch/timeseries/common/exception/ClientException.java-e new file mode 100644 index 000000000..d8be97f37 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/common/exception/ClientException.java-e @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.common.exception; + +/** + * All exception visible to transport layer's client is under ClientException. + */ +public class ClientException extends TimeSeriesException { + + public ClientException(String message) { + super(message); + } + + public ClientException(String configId, String message) { + super(configId, message); + } + + public ClientException(String configId, String message, Throwable throwable) { + super(configId, message, throwable); + } + + public ClientException(String configId, Throwable cause) { + super(configId, cause); + } +} diff --git a/src/main/java/org/opensearch/timeseries/common/exception/DuplicateTaskException.java-e b/src/main/java/org/opensearch/timeseries/common/exception/DuplicateTaskException.java-e new file mode 100644 index 000000000..1791e322d --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/common/exception/DuplicateTaskException.java-e @@ -0,0 +1,20 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.common.exception; + +public class DuplicateTaskException extends TimeSeriesException { + + public DuplicateTaskException(String msg) { + super(msg); + this.countedInStats(false); + } +} diff --git a/src/main/java/org/opensearch/timeseries/common/exception/EndRunException.java-e b/src/main/java/org/opensearch/timeseries/common/exception/EndRunException.java-e new file mode 100644 index 000000000..a4b11c621 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/common/exception/EndRunException.java-e @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.common.exception; + +/** + * Exception for failures that might impact the customer. + * + */ +public class EndRunException extends ClientException { + private boolean endNow; + + public EndRunException(String message, boolean endNow) { + super(message); + this.endNow = endNow; + } + + public EndRunException(String configId, String message, boolean endNow) { + super(configId, message); + this.endNow = endNow; + } + + public EndRunException(String configId, String message, Throwable throwable, boolean endNow) { + super(configId, message, throwable); + this.endNow = endNow; + } + + /** + * @return true for "unrecoverable issue". We want to terminate the detector run immediately. + * false for "maybe unrecoverable issue but worth retrying a few more times." We want + * to wait for a few more times on different requests before terminating the detector run. + */ + public boolean isEndNow() { + return endNow; + } +} diff --git a/src/main/java/org/opensearch/timeseries/common/exception/InternalFailure.java-e b/src/main/java/org/opensearch/timeseries/common/exception/InternalFailure.java-e new file mode 100644 index 000000000..c7c9048cb --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/common/exception/InternalFailure.java-e @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.common.exception; + +/** + * Exception for root cause unknown failure. Maybe transient. Client can continue the detector running. 
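+ * When wrapping another TimeSeriesException, the original exception's config ID is carried over to this failure.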
+ * + */ +public class InternalFailure extends ClientException { + + public InternalFailure(String configId, String message) { + super(configId, message); + } + + public InternalFailure(String configId, String message, Throwable cause) { + super(configId, message, cause); + } + + public InternalFailure(String configId, Throwable cause) { + super(configId, cause); + } + + public InternalFailure(TimeSeriesException cause) { + super(cause.getConfigId(), cause); + } +} diff --git a/src/main/java/org/opensearch/timeseries/common/exception/LimitExceededException.java-e b/src/main/java/org/opensearch/timeseries/common/exception/LimitExceededException.java-e new file mode 100644 index 000000000..e51a2bc4e --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/common/exception/LimitExceededException.java-e @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.common.exception; + +/** + * This exception is thrown when a user/system limit is exceeded. + */ +public class LimitExceededException extends EndRunException { + + /** + * Constructor with a config ID and an explanation. + * + * @param id ID of the time series analysis for which the limit is exceeded + * @param message explanation for the limit + */ + public LimitExceededException(String id, String message) { + super(id, message, true); + this.countedInStats(false); + } + + /** + * Constructor with error message. + * + * @param message explanation for the limit + */ + public LimitExceededException(String message) { + super(message, true); + } + + /** + * Constructor with error message. + * + * @param message explanation for the limit + * @param endRun end detector run or not + */ + public LimitExceededException(String message, boolean endRun) { + super(null, message, endRun); + } + + /** + * Constructor with a config ID and an explanation, and a flag for stopping. 
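+ * Like the two-argument (id, message) constructor, this variant is excluded from failure stats.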
+ * + * @param id ID of the time series analysis for which the limit is exceeded + * @param message explanation for the limit + * @param stopNow whether to stop time series analysis immediately + */ + public LimitExceededException(String id, String message, boolean stopNow) { + super(id, message, stopNow); + this.countedInStats(false); + } +} diff --git a/src/main/java/org/opensearch/timeseries/common/exception/NotSerializedExceptionName.java b/src/main/java/org/opensearch/timeseries/common/exception/NotSerializedExceptionName.java index 70e1f5e4d..b85005e01 100644 --- a/src/main/java/org/opensearch/timeseries/common/exception/NotSerializedExceptionName.java +++ b/src/main/java/org/opensearch/timeseries/common/exception/NotSerializedExceptionName.java @@ -18,7 +18,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; /** * OpenSearch restricts the kind of exceptions that can be thrown over the wire diff --git a/src/main/java/org/opensearch/timeseries/common/exception/NotSerializedExceptionName.java-e b/src/main/java/org/opensearch/timeseries/common/exception/NotSerializedExceptionName.java-e new file mode 100644 index 000000000..b85005e01 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/common/exception/NotSerializedExceptionName.java-e @@ -0,0 +1,111 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.common.exception; + +import static org.opensearch.OpenSearchException.getExceptionName; + +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; + +/** + * OpenSearch restricts the kind of exceptions that can be thrown over the wire + * (Read OpenSearchException.OpenSearchExceptionHandle https://tinyurl.com/wv6c6t7x). + * Since we cannot add our own exception like ResourceNotFoundException without modifying + * OpenSearch's code, we have to unwrap the NotSerializableExceptionWrapper and + * check its root cause message. 
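+ * For example, a wrapped LimitExceededException arrives with its exception name as the message prefix; convertWrappedTimeSeriesException matches that prefix and rebuilds a LimitExceededException for the given config ID.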
+ * + */ +public enum NotSerializedExceptionName { + + RESOURCE_NOT_FOUND_EXCEPTION_NAME_UNDERSCORE(getExceptionName(new ResourceNotFoundException("", ""))), + LIMIT_EXCEEDED_EXCEPTION_NAME_UNDERSCORE(getExceptionName(new LimitExceededException("", "", false))), + END_RUN_EXCEPTION_NAME_UNDERSCORE(getExceptionName(new EndRunException("", "", false))), + TIME_SERIES_DETECTION_EXCEPTION_NAME_UNDERSCORE(getExceptionName(new TimeSeriesException("", ""))), + INTERNAL_FAILURE_NAME_UNDERSCORE(getExceptionName(new InternalFailure("", ""))), + CLIENT_EXCEPTION_NAME_UNDERSCORE(getExceptionName(new ClientException("", ""))), + CANCELLATION_EXCEPTION_NAME_UNDERSCORE(getExceptionName(new TaskCancelledException("", ""))), + DUPLICATE_TASK_EXCEPTION_NAME_UNDERSCORE(getExceptionName(new DuplicateTaskException(""))), + VERSION_EXCEPTION_NAME_UNDERSCORE(getExceptionName(new VersionException(""))), + VALIDATION_EXCEPTION_NAME_UNDERSCORE(getExceptionName(new ValidationException("", null, null))); + + private static final Logger LOG = LogManager.getLogger(NotSerializedExceptionName.class); + private final String name; + + NotSerializedExceptionName(String name) { + this.name = name; + } + + public String getName() { + return name; + } + + /** + * Convert from a NotSerializableExceptionWrapper to a TimeSeriesException. + * Since NotSerializableExceptionWrapper does not keep some details we need, we + * initialize the exception with default values. + * @param exception an NotSerializableExceptionWrapper exception. + * @param configID Config Id. + * @return converted TimeSeriesException + */ + public static Optional convertWrappedTimeSeriesException( + NotSerializableExceptionWrapper exception, + String configID + ) { + String exceptionMsg = exception.getMessage().trim(); + + TimeSeriesException convertedException = null; + for (NotSerializedExceptionName timeseriesException : values()) { + if (exceptionMsg.startsWith(timeseriesException.getName())) { + switch (timeseriesException) { + case RESOURCE_NOT_FOUND_EXCEPTION_NAME_UNDERSCORE: + convertedException = new ResourceNotFoundException(configID, exceptionMsg); + break; + case LIMIT_EXCEEDED_EXCEPTION_NAME_UNDERSCORE: + convertedException = new LimitExceededException(configID, exceptionMsg, false); + break; + case END_RUN_EXCEPTION_NAME_UNDERSCORE: + convertedException = new EndRunException(configID, exceptionMsg, false); + break; + case TIME_SERIES_DETECTION_EXCEPTION_NAME_UNDERSCORE: + convertedException = new TimeSeriesException(configID, exceptionMsg); + break; + case INTERNAL_FAILURE_NAME_UNDERSCORE: + convertedException = new InternalFailure(configID, exceptionMsg); + break; + case CLIENT_EXCEPTION_NAME_UNDERSCORE: + convertedException = new ClientException(configID, exceptionMsg); + break; + case CANCELLATION_EXCEPTION_NAME_UNDERSCORE: + convertedException = new TaskCancelledException(exceptionMsg, ""); + break; + case DUPLICATE_TASK_EXCEPTION_NAME_UNDERSCORE: + convertedException = new DuplicateTaskException(exceptionMsg); + break; + case VERSION_EXCEPTION_NAME_UNDERSCORE: + convertedException = new VersionException(exceptionMsg); + break; + case VALIDATION_EXCEPTION_NAME_UNDERSCORE: + convertedException = new ValidationException(exceptionMsg, null, null); + break; + default: + LOG.warn(new ParameterizedMessage("Unexpected exception {}", timeseriesException)); + break; + } + } + } + return Optional.ofNullable(convertedException); + } +} diff --git a/src/main/java/org/opensearch/timeseries/common/exception/ResourceNotFoundException.java-e 
b/src/main/java/org/opensearch/timeseries/common/exception/ResourceNotFoundException.java-e new file mode 100644 index 000000000..eddbcac99 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/common/exception/ResourceNotFoundException.java-e @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.common.exception; + +/** + * This exception is thrown when a resource is not found. + */ +public class ResourceNotFoundException extends TimeSeriesException { + + /** + * Constructor with a config ID and a message. + * + * @param configId ID of the config related to the resource + * @param message explains which resource is not found + */ + public ResourceNotFoundException(String configId, String message) { + super(configId, message); + countedInStats(false); + } + + public ResourceNotFoundException(String message) { + super(message); + countedInStats(false); + } +} diff --git a/src/main/java/org/opensearch/timeseries/common/exception/TaskCancelledException.java-e b/src/main/java/org/opensearch/timeseries/common/exception/TaskCancelledException.java-e new file mode 100644 index 000000000..ba0c3d600 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/common/exception/TaskCancelledException.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.common.exception; + +public class TaskCancelledException extends TimeSeriesException { + private String cancelledBy; + + public TaskCancelledException(String msg, String user) { + super(msg); + this.cancelledBy = user; + this.countedInStats(false); + } + + public String getCancelledBy() { + return cancelledBy; + } +} diff --git a/src/main/java/org/opensearch/timeseries/common/exception/TimeSeriesException.java-e b/src/main/java/org/opensearch/timeseries/common/exception/TimeSeriesException.java-e new file mode 100644 index 000000000..caa2573a9 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/common/exception/TimeSeriesException.java-e @@ -0,0 +1,90 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.common.exception; + +/** + * Base exception for exceptions thrown. + */ +public class TimeSeriesException extends RuntimeException { + + private String configId; + // countedInStats will be used to tell whether the exception should be + // counted in failure stats. + private boolean countedInStats = true; + + public TimeSeriesException(String message) { + super(message); + } + + /** + * Constructor with a config ID and a message. 
+ * + * @param configId config ID + * @param message message of the exception + */ + public TimeSeriesException(String configId, String message) { + super(message); + this.configId = configId; + } + + public TimeSeriesException(String configID, String message, Throwable cause) { + super(message, cause); + this.configId = configID; + } + + public TimeSeriesException(Throwable cause) { + super(cause); + } + + public TimeSeriesException(String configID, Throwable cause) { + super(cause); + this.configId = configID; + } + + /** + * Returns the ID of the analysis config. + * + * @return config ID + */ + public String getConfigId() { + return this.configId; + } + + /** + * Returns if the exception should be counted in stats. + * + * @return true if should count the exception in stats; otherwise return false + */ + public boolean isCountedInStats() { + return countedInStats; + } + + /** + * Set if the exception should be counted in stats. + * + * @param countInStats count the exception in stats + * @return the exception itself + */ + public TimeSeriesException countedInStats(boolean countInStats) { + this.countedInStats = countInStats; + return this; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(configId); + sb.append(' '); + sb.append(super.toString()); + return sb.toString(); + } +} diff --git a/src/main/java/org/opensearch/timeseries/common/exception/ValidationException.java-e b/src/main/java/org/opensearch/timeseries/common/exception/ValidationException.java-e new file mode 100644 index 000000000..4c18c13fe --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/common/exception/ValidationException.java-e @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.common.exception; + +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; + +public class ValidationException extends TimeSeriesException { + private final ValidationIssueType type; + private final ValidationAspect aspect; + private final IntervalTimeConfiguration intervalSuggestion; + + public ValidationIssueType getType() { + return type; + } + + public ValidationAspect getAspect() { + return aspect; + } + + public IntervalTimeConfiguration getIntervalSuggestion() { + return intervalSuggestion; + } + + public ValidationException(String message, ValidationIssueType type, ValidationAspect aspect) { + this(message, null, type, aspect, null); + } + + public ValidationException( + String message, + ValidationIssueType type, + ValidationAspect aspect, + IntervalTimeConfiguration intervalSuggestion + ) { + this(message, null, type, aspect, intervalSuggestion); + } + + public ValidationException( + String message, + Throwable cause, + ValidationIssueType type, + ValidationAspect aspect, + IntervalTimeConfiguration intervalSuggestion + ) { + super(Config.NO_ID, message, cause); + this.type = type; + this.aspect = aspect; + this.intervalSuggestion = intervalSuggestion; + } + + @Override + public String toString() { + String superString = super.toString(); + StringBuilder sb = new StringBuilder(superString); + if (type != null) { + sb.append(" type: "); + sb.append(type.getName()); + } + + if (aspect != null) { + sb.append(" aspect: "); + sb.append(aspect.getName()); + } + + if (intervalSuggestion != null) { + sb.append(" interval suggestion: "); + sb.append(intervalSuggestion.getInterval()); + sb.append(intervalSuggestion.getUnit()); + } + + return sb.toString(); + } +} diff --git a/src/main/java/org/opensearch/timeseries/common/exception/VersionException.java-e b/src/main/java/org/opensearch/timeseries/common/exception/VersionException.java-e new file mode 100644 index 000000000..b9fac314c --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/common/exception/VersionException.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.common.exception; + +/** + * AD version incompatible exception. + */ +public class VersionException extends TimeSeriesException { + + public VersionException(String message) { + super(message); + } + + public VersionException(String configId, String message) { + super(configId, message); + } +} diff --git a/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java-e b/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java-e new file mode 100644 index 000000000..393248237 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java-e @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.constant; + +import java.util.Locale; + +public class CommonMessages { + // ====================================== + // Validation message + // ====================================== + public static String NEGATIVE_TIME_CONFIGURATION = "should be non-negative"; + public static String INVALID_RESULT_INDEX_MAPPING = "Result index mapping is not correct for index: "; + public static String EMPTY_NAME = "name should be set"; + public static String NULL_TIME_FIELD = "Time field should be set"; + public static String EMPTY_INDICES = "Indices should be set"; + + public static String getTooManyCategoricalFieldErr(int limit) { + return String.format(Locale.ROOT, CommonMessages.TOO_MANY_CATEGORICAL_FIELD_ERR_MSG_FORMAT, limit); + } + + public static final String TOO_MANY_CATEGORICAL_FIELD_ERR_MSG_FORMAT = + "Currently we only support up to %d categorical field/s in order to bound system resource consumption."; + + public static String CAN_NOT_FIND_RESULT_INDEX = "Can't find result index "; + public static String INVALID_CHAR_IN_RESULT_INDEX_NAME = + "Result index name has invalid character. Valid characters are a-z, 0-9, -(hyphen) and _(underscore)"; + public static String FAIL_TO_VALIDATE = "failed to validate"; + public static String INVALID_TIMESTAMP = "Timestamp field: (%s) must be of type date"; + public static String NON_EXISTENT_TIMESTAMP = "Timestamp field: (%s) is not found in index mapping"; + public static String INVALID_NAME = "Valid characters for name are a-z, A-Z, 0-9, -(hyphen), _(underscore) and .(period)"; + // change this error message to make it compatible with old version's integration(nexus) test + public static String FAIL_TO_FIND_CONFIG_MSG = "Can't find config with id: "; + public static final String CAN_NOT_CHANGE_CATEGORY_FIELD = "Can't change category field"; + public static final String CAN_NOT_CHANGE_CUSTOM_RESULT_INDEX = "Can't change custom result index"; + public static final String CATEGORICAL_FIELD_TYPE_ERR_MSG = "A categorical field must be of type keyword or ip."; + // Modifying message for FEATURE below may break the parseADValidationException method of ValidateAnomalyDetectorTransportAction + public static final String FEATURE_INVALID_MSG_PREFIX = "Feature has an invalid query"; + public static final String FEATURE_WITH_EMPTY_DATA_MSG = FEATURE_INVALID_MSG_PREFIX + " returning empty aggregated data: "; + public static final String FEATURE_WITH_INVALID_QUERY_MSG = FEATURE_INVALID_MSG_PREFIX + " causing a runtime exception: "; + public static final String UNKNOWN_SEARCH_QUERY_EXCEPTION_MSG = + "Feature has an unknown exception caught while executing the feature query: "; + public static String DUPLICATE_FEATURE_AGGREGATION_NAMES = "Config has duplicate feature aggregation query names: "; + + // ====================================== + // Index message + // ====================================== + public static final String CREATE_INDEX_NOT_ACKNOWLEDGED = "Create index %S not acknowledged by OpenSearch core"; + public static final String SUCCESS_SAVING_RESULT_MSG = "Result saved successfully."; + public static final String CANNOT_SAVE_RESULT_ERR_MSG = "Cannot save results due to write block."; + + // ====================================== + // Resource constraints + // ====================================== + public static final String MEMORY_CIRCUIT_BROKEN_ERR_MSG = + "The total OpenSearch memory usage exceeds our threshold, opening the AD memory circuit."; + + // ====================================== + // Transport + // 
====================================== + public static final String INVALID_TIMESTAMP_ERR_MSG = "timestamp is invalid"; + + // ====================================== + // transport/restful client + // ====================================== + public static final String WAIT_ERR_MSG = "Exception in waiting for result"; + public static final String ALL_FEATURES_DISABLED_ERR_MSG = + "Having trouble querying data because all of your features have been disabled."; + // We need this invalid query tag to show proper error message on frontend + // refer to AD Dashboard code: https://tinyurl.com/8b5n8hat + public static final String INVALID_SEARCH_QUERY_MSG = "Invalid search query."; + public static final String NO_REQUESTS_ADDED_ERR = "no requests added"; + + // ====================================== + // rate limiting worker + // ====================================== + public static final String BUG_RESPONSE = "We might have bugs."; + public static final String MEMORY_LIMIT_EXCEEDED_ERR_MSG = "Models memory usage exceeds our limit."; + +} diff --git a/src/main/java/org/opensearch/timeseries/constant/CommonName.java b/src/main/java/org/opensearch/timeseries/constant/CommonName.java index 129378ceb..0b997ea5d 100644 --- a/src/main/java/org/opensearch/timeseries/constant/CommonName.java +++ b/src/main/java/org/opensearch/timeseries/constant/CommonName.java @@ -106,4 +106,11 @@ public class CommonName { public static final String MODEL_ID_KEY = "model_id"; public static final String TASK_ID_FIELD = "task_id"; public static final String ENTITY_ID_FIELD = "entity_id"; + + // ====================================== + // plugin info + // ====================================== + public static final String TIME_SERIES_PLUGIN_NAME = "opensearch-time-series-analytics"; + public static final String TIME_SERIES_PLUGIN_NAME_FOR_TEST = "org.opensearch.timeseries.TimeSeriesAnalyticsPlugin"; + public static final String TIME_SERIES_PLUGIN_VERSION_FOR_TEST = "NA"; } diff --git a/src/main/java/org/opensearch/timeseries/constant/CommonName.java-e b/src/main/java/org/opensearch/timeseries/constant/CommonName.java-e new file mode 100644 index 000000000..0b997ea5d --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/constant/CommonName.java-e @@ -0,0 +1,116 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.constant; + +public class CommonName { + + // ====================================== + // Index mapping + // ====================================== + // Elastic mapping type + public static final String MAPPING_TYPE = "_doc"; + // used for updating mapping + public static final String SCHEMA_VERSION_FIELD = "schema_version"; + + // Used to fetch mapping + public static final String TYPE = "type"; + public static final String KEYWORD_TYPE = "keyword"; + public static final String IP_TYPE = "ip"; + public static final String DATE_TYPE = "date"; + + // ====================================== + // Index name + // ====================================== + // config index. We are reusing ad detector index. + public static final String CONFIG_INDEX = ".opendistro-anomaly-detectors"; + + // job index. We are reusing ad job index. 
+ public static final String JOB_INDEX = ".opendistro-anomaly-detector-jobs"; + + // ====================================== + // Validation + // ====================================== + public static final String MODEL_ASPECT = "model"; + public static final String CONFIG_ID_MISSING_MSG = "config ID is missing"; + + // ====================================== + // Used for custom forecast result index + // ====================================== + public static final String PROPERTIES = "properties"; + + // ====================================== + // Used in toXContent + // ====================================== + public static final String START_JSON_KEY = "start"; + public static final String END_JSON_KEY = "end"; + public static final String ENTITIES_JSON_KEY = "entities"; + public static final String ENTITY_KEY = "entity"; + public static final String VALUE_JSON_KEY = "value"; + public static final String VALUE_LIST_FIELD = "value_list"; + public static final String FEATURE_DATA_FIELD = "feature_data"; + public static final String DATA_START_TIME_FIELD = "data_start_time"; + public static final String DATA_END_TIME_FIELD = "data_end_time"; + public static final String EXECUTION_START_TIME_FIELD = "execution_start_time"; + public static final String EXECUTION_END_TIME_FIELD = "execution_end_time"; + public static final String ERROR_FIELD = "error"; + public static final String ENTITY_FIELD = "entity"; + public static final String USER_FIELD = "user"; + public static final String CONFIDENCE_FIELD = "confidence"; + public static final String DATA_QUALITY_FIELD = "data_quality"; + // MODEL_ID_FIELD can be used in profile and stats API as well + public static final String MODEL_ID_FIELD = "model_id"; + public static final String TIMESTAMP = "timestamp"; + public static final String FIELD_MODEL = "model"; + + // entity sample in checkpoint. 
+ // kept for bwc purpose + public static final String ENTITY_SAMPLE = "sp"; + // current key for entity samples + public static final String ENTITY_SAMPLE_QUEUE = "samples"; + + // ====================================== + // Profile name + // ====================================== + public static final String MODEL_SIZE_IN_BYTES = "model_size_in_bytes"; + + // ====================================== + // Used for backward-compatibility in messaging + // ====================================== + public static final String EMPTY_FIELD = ""; + + // ====================================== + // Query + // ====================================== + // Used in finding the max timestamp + public static final String AGG_NAME_MAX_TIME = "max_timefield"; + // Used in finding the min timestamp + public static final String AGG_NAME_MIN_TIME = "min_timefield"; + // date histogram aggregation name + public static final String DATE_HISTOGRAM = "date_histogram"; + // feature aggregation name + public static final String FEATURE_AGGS = "feature_aggs"; + + // ====================================== + // Used in toXContent + // ====================================== + public static final String CONFIG_ID_KEY = "config_id"; + public static final String MODEL_ID_KEY = "model_id"; + public static final String TASK_ID_FIELD = "task_id"; + public static final String ENTITY_ID_FIELD = "entity_id"; + + // ====================================== + // plugin info + // ====================================== + public static final String TIME_SERIES_PLUGIN_NAME = "opensearch-time-series-analytics"; + public static final String TIME_SERIES_PLUGIN_NAME_FOR_TEST = "org.opensearch.timeseries.TimeSeriesAnalyticsPlugin"; + public static final String TIME_SERIES_PLUGIN_VERSION_FOR_TEST = "NA"; +} diff --git a/src/main/java/org/opensearch/timeseries/constant/CommonValue.java-e b/src/main/java/org/opensearch/timeseries/constant/CommonValue.java-e new file mode 100644 index 000000000..6f05f59d0 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/constant/CommonValue.java-e @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.constant; + +public class CommonValue { + // unknown or no schema version + public static Integer NO_SCHEMA_VERSION = 0; + +} diff --git a/src/main/java/org/opensearch/timeseries/dataprocessor/FixedValueImputer.java-e b/src/main/java/org/opensearch/timeseries/dataprocessor/FixedValueImputer.java-e new file mode 100644 index 000000000..9b8f6bf21 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/dataprocessor/FixedValueImputer.java-e @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.dataprocessor; + +import java.util.Arrays; + +/** + * fixing missing value (denoted using Double.NaN) using a fixed set of specified values. + * The 2nd parameter of interpolate is ignored as we infer the number of imputed values + * using the number of Double.NaN. + */ +public class FixedValueImputer extends Imputer { + private double[] fixedValue; + + public FixedValueImputer(double[] fixedValue) { + this.fixedValue = fixedValue; + } + + /** + * Given an array of samples, fill with given value. + * We will ignore the rest of samples beyond the 2nd element. 
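+ * For example, samples {{1.0, Double.NaN, 3.0}} with fixedValue {0.5} become {{1.0, 0.5, 3.0}}.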
+ * + * @return an imputed array of size numImputed + */ + @Override + public double[][] impute(double[][] samples, int numImputed) { + int numFeatures = samples.length; + double[][] imputed = new double[numFeatures][numImputed]; + + for (int featureIndex = 0; featureIndex < numFeatures; featureIndex++) { + imputed[featureIndex] = singleFeatureInterpolate(samples[featureIndex], numImputed, fixedValue[featureIndex]); + } + return imputed; + } + + private double[] singleFeatureInterpolate(double[] samples, int numInterpolants, double defaultVal) { + return Arrays.stream(samples).map(d -> Double.isNaN(d) ? defaultVal : d).toArray(); + } + + @Override + protected double[] singleFeatureImpute(double[] samples, int numInterpolants) { + throw new UnsupportedOperationException("The operation is not supported"); + } +} diff --git a/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationMethod.java-e b/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationMethod.java-e new file mode 100644 index 000000000..90494862c --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationMethod.java-e @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.dataprocessor; + +public enum ImputationMethod { + /** + * This method replaces all missing values with 0's. It's a simple approach, but it may introduce bias if the data is not centered around zero. + */ + ZERO, + /** + * This method replaces missing values with a predefined set of values. The values are the same for each input dimension, and they need to be specified by the user. + */ + FIXED_VALUES, + /** + * This method replaces missing values with the last known value in the respective input dimension. It's a commonly used method for time series data, where temporal continuity is expected. + */ + PREVIOUS, + /** + * This method estimates missing values by interpolating linearly between known values in the respective input dimension. This method assumes that the data follows a linear trend. 
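+ * For example, a single missing point between observed values 2 and 4 would be filled with 3.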
+ */ + LINEAR +} diff --git a/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationOption.java b/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationOption.java index e073d1316..9098aac14 100644 --- a/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationOption.java +++ b/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationOption.java @@ -5,7 +5,7 @@ package org.opensearch.timeseries.dataprocessor; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.util.ArrayList; @@ -15,9 +15,9 @@ import java.util.Objects; import java.util.Optional; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; diff --git a/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationOption.java-e b/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationOption.java-e new file mode 100644 index 000000000..9098aac14 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/dataprocessor/ImputationOption.java-e @@ -0,0 +1,147 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.dataprocessor; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Optional; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +public class ImputationOption implements Writeable, ToXContent { + // field name in toXContent + public static final String METHOD_FIELD = "method"; + public static final String DEFAULT_FILL_FIELD = "defaultFill"; + public static final String INTEGER_SENSITIVE_FIELD = "integerSensitive"; + + private final ImputationMethod method; + private final Optional defaultFill; + private final boolean integerSentive; + + public ImputationOption(ImputationMethod method, Optional defaultFill, boolean integerSentive) { + this.method = method; + this.defaultFill = defaultFill; + this.integerSentive = integerSentive; + } + + public ImputationOption(ImputationMethod method) { + this(method, Optional.empty(), false); + } + + public ImputationOption(StreamInput in) throws IOException { + this.method = in.readEnum(ImputationMethod.class); + if (in.readBoolean()) { + this.defaultFill = Optional.of(in.readDoubleArray()); + } else { + this.defaultFill = Optional.empty(); + } + this.integerSentive = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(method); + if (defaultFill.isEmpty()) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + 
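+ // the presence flag written above is read back first in ImputationOption(StreamInput), followed by the array itself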
out.writeDoubleArray(defaultFill.get()); + } + out.writeBoolean(integerSentive); + } + + public static ImputationOption parse(XContentParser parser) throws IOException { + ImputationMethod method = ImputationMethod.ZERO; + List defaultFill = null; + Boolean integerSensitive = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case METHOD_FIELD: + method = ImputationMethod.valueOf(parser.text().toUpperCase(Locale.ROOT)); + break; + case DEFAULT_FILL_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + defaultFill = new ArrayList<>(); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + defaultFill.add(parser.doubleValue()); + } + break; + case INTEGER_SENSITIVE_FIELD: + integerSensitive = parser.booleanValue(); + break; + default: + break; + } + } + return new ImputationOption( + method, + Optional.ofNullable(defaultFill).map(list -> list.stream().mapToDouble(Double::doubleValue).toArray()), + integerSensitive + ); + } + + public XContentBuilder toXContent(XContentBuilder builder) throws IOException { + return toXContent(builder, ToXContent.EMPTY_PARAMS); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + + builder.field(METHOD_FIELD, method); + + if (!defaultFill.isEmpty()) { + builder.array(DEFAULT_FILL_FIELD, defaultFill.get()); + } + builder.field(INTEGER_SENSITIVE_FIELD, integerSentive); + return xContentBuilder.endObject(); + } + + @Override + public boolean equals(Object o) { + if (o == this) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + ImputationOption other = (ImputationOption) o; + return method == other.method + && (defaultFill.isEmpty() ? other.defaultFill.isEmpty() : Arrays.equals(defaultFill.get(), other.defaultFill.get())) + && integerSentive == other.integerSentive; + } + + @Override + public int hashCode() { + return Objects.hash(method, (defaultFill.isEmpty() ? 0 : Arrays.hashCode(defaultFill.get())), integerSentive); + } + + public ImputationMethod getMethod() { + return method; + } + + public Optional getDefaultFill() { + return defaultFill; + } + + public boolean isIntegerSentive() { + return integerSentive; + } +} diff --git a/src/main/java/org/opensearch/timeseries/dataprocessor/Imputer.java-e b/src/main/java/org/opensearch/timeseries/dataprocessor/Imputer.java-e new file mode 100644 index 000000000..4e885421c --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/dataprocessor/Imputer.java-e @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.dataprocessor; + +/* + * An object for imputing feature vectors. + * + * In certain situations, due to time and compute cost, we are only allowed to + * query a sparse sample of data points / feature vectors from a cluster. + * However, we need a large sample of feature vectors in order to train our + * anomaly detection algorithms. An Imputer approximates the data points + * between a given, ordered list of samples. + */ +public abstract class Imputer { + + /** + * Imputes the given sample feature vectors. 
+ * + * Computes a list `numImputed` feature vectors using the ordered list + * of `numSamples` input sample vectors where each sample vector has size + * `numFeatures`. + * + * + * @param samples A `numFeatures x numSamples` list of feature vectors. + * @param numImputed The desired number of imputed vectors. + * @return A `numFeatures x numImputed` list of feature vectors. + */ + public double[][] impute(double[][] samples, int numImputed) { + int numFeatures = samples.length; + double[][] interpolants = new double[numFeatures][numImputed]; + + for (int featureIndex = 0; featureIndex < numFeatures; featureIndex++) { + interpolants[featureIndex] = singleFeatureImpute(samples[featureIndex], numImputed); + } + return interpolants; + } + + /** + * compute per-feature impute value + * @param samples input array + * @param numImputed number of elements in the return array + * @return input array with missing values imputed + */ + protected abstract double[] singleFeatureImpute(double[] samples, int numImputed); +} diff --git a/src/main/java/org/opensearch/timeseries/dataprocessor/LinearUniformImputer.java-e b/src/main/java/org/opensearch/timeseries/dataprocessor/LinearUniformImputer.java-e new file mode 100644 index 000000000..2fa3fd651 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/dataprocessor/LinearUniformImputer.java-e @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.dataprocessor; + +import java.util.Arrays; + +import com.google.common.math.DoubleMath; + +/** + * A piecewise linear imputer with uniformly spaced points. + * + * The LinearUniformImputer constructs a piecewise linear imputation on + * the input list of sample feature vectors. That is, between every consecutive + * pair of points we construct a linear imputation. The linear imputation + * is computed on a per-feature basis. + * + */ +public class LinearUniformImputer extends Imputer { + // if true, impute integral/floating-point results: when all samples are integral, + // the results are integral. Else, the results are floating points. + private boolean integerSensitive; + + public LinearUniformImputer(boolean integerSensitive) { + this.integerSensitive = integerSensitive; + } + + /* + * Piecewise linearly impute the given sample of one-dimensional + * features. + * + * Computes a list `numImputed` features using the ordered list of + * `numSamples` input one-dimensional samples. The imputed features are + * computing using a piecewise linear imputation. + * + * @param samples A `numSamples` sized list of sample features. + * @param numImputed The desired number of imputed features. + * @return A `numImputed` sized array of imputed features. + */ + @Override + public double[] singleFeatureImpute(double[] samples, int numImputed) { + int numSamples = samples.length; + double[] imputedValues = new double[numImputed]; + + if (numSamples == 0) { + imputedValues = new double[0]; + } else if (numSamples == 1) { + Arrays.fill(imputedValues, samples[0]); + } else { + /* assume the piecewise linear imputation between the samples is a + parameterized curve f(t) for t in [0, 1]. Each pair of samples + determines a interval [t_i, t_(i+1)]. For each imputed value we + determine which interval it lies inside and then scale the value of t, + accordingly to compute the imputed value. 
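+ For example, samples {0, 10} with numImputed = 3 yield {0, 5, 10}: the midpoint falls halfway between the two samples and the endpoints reproduce them.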
+ + for numerical stability reasons we omit processing the final + imputed value in this loop since this last imputed value is always equal + to the last sample. + */ + for (int imputedIndex = 0; imputedIndex < (numImputed - 1); imputedIndex++) { + double tGlobal = (imputedIndex) / (numImputed - 1.0); + double tInterval = tGlobal * (numSamples - 1.0); + int intervalIndex = (int) Math.floor(tInterval); + tInterval -= intervalIndex; + + double leftSample = samples[intervalIndex]; + double rightSample = samples[intervalIndex + 1]; + double imputed = (1.0 - tInterval) * leftSample + tInterval * rightSample; + imputedValues[imputedIndex] = imputed; + } + + // the final imputed value is always the final sample + imputedValues[numImputed - 1] = samples[numSamples - 1]; + } + if (integerSensitive && Arrays.stream(samples).allMatch(DoubleMath::isMathematicalInteger)) { + imputedValues = Arrays.stream(imputedValues).map(Math::rint).toArray(); + } + return imputedValues; + } +} diff --git a/src/main/java/org/opensearch/timeseries/dataprocessor/PreviousValueImputer.java-e b/src/main/java/org/opensearch/timeseries/dataprocessor/PreviousValueImputer.java-e new file mode 100644 index 000000000..e91c90814 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/dataprocessor/PreviousValueImputer.java-e @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.dataprocessor; + +/** + * Given an array of samples, fill missing values (represented using Double.NaN) + * with previous value. + * The return array may be smaller than the input array as we remove leading missing + * values after interpolation. If the first sample is Double.NaN + * as there is no last known value to fill in. + * The 2nd parameter of interpolate is ignored as we infer the number of imputed values + * using the number of Double.NaN. + * + */ +public class PreviousValueImputer extends Imputer { + + @Override + protected double[] singleFeatureImpute(double[] samples, int numInterpolants) { + int numSamples = samples.length; + double[] interpolants = new double[numSamples]; + + if (numSamples > 0) { + System.arraycopy(samples, 0, interpolants, 0, samples.length); + if (numSamples > 1) { + double lastKnownValue = Double.NaN; + for (int interpolantIndex = 0; interpolantIndex < numSamples; interpolantIndex++) { + if (Double.isNaN(interpolants[interpolantIndex])) { + if (!Double.isNaN(lastKnownValue)) { + interpolants[interpolantIndex] = lastKnownValue; + } + } else { + lastKnownValue = interpolants[interpolantIndex]; + } + } + } + } + return interpolants; + } +} diff --git a/src/main/java/org/opensearch/timeseries/dataprocessor/ZeroImputer.java-e b/src/main/java/org/opensearch/timeseries/dataprocessor/ZeroImputer.java-e new file mode 100644 index 000000000..1d0656de1 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/dataprocessor/ZeroImputer.java-e @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.dataprocessor; + +import java.util.Arrays; + +/** + * fixing missing value (denoted using Double.NaN) using 0. + * The 2nd parameter of impute is ignored as we infer the number + * of imputed values using the number of Double.NaN. + */ +public class ZeroImputer extends Imputer { + + @Override + public double[] singleFeatureImpute(double[] samples, int numInterpolants) { + return Arrays.stream(samples).map(d -> Double.isNaN(d) ? 
0.0 : d).toArray(); + } +} diff --git a/src/main/java/org/opensearch/timeseries/function/ExecutorFunction.java-e b/src/main/java/org/opensearch/timeseries/function/ExecutorFunction.java-e new file mode 100644 index 000000000..90cd93cfb --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/function/ExecutorFunction.java-e @@ -0,0 +1,23 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.function; + +@FunctionalInterface +public interface ExecutorFunction { + + /** + * Performs this operation. + * + * Notes: don't forget to send back responses via channel if you process response with this method. + */ + void execute(); +} diff --git a/src/main/java/org/opensearch/timeseries/function/ThrowingConsumer.java-e b/src/main/java/org/opensearch/timeseries/function/ThrowingConsumer.java-e new file mode 100644 index 000000000..8a7210f01 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/function/ThrowingConsumer.java-e @@ -0,0 +1,23 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.function; + +/** + * A consumer that can throw checked exception + * + * @param method parameter type + * @param Exception type + */ +@FunctionalInterface +public interface ThrowingConsumer { + void accept(T t) throws E; +} diff --git a/src/main/java/org/opensearch/timeseries/function/ThrowingSupplier.java-e b/src/main/java/org/opensearch/timeseries/function/ThrowingSupplier.java-e new file mode 100644 index 000000000..a56f513a8 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/function/ThrowingSupplier.java-e @@ -0,0 +1,23 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.function; + +/** + * A supplier that can throw checked exception + * + * @param method parameter type + * @param Exception type + */ +@FunctionalInterface +public interface ThrowingSupplier { + T get() throws E; +} diff --git a/src/main/java/org/opensearch/timeseries/function/ThrowingSupplierWrapper.java-e b/src/main/java/org/opensearch/timeseries/function/ThrowingSupplierWrapper.java-e new file mode 100644 index 000000000..c57b11d33 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/function/ThrowingSupplierWrapper.java-e @@ -0,0 +1,43 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.function; + +import java.util.function.Supplier; + +public class ThrowingSupplierWrapper { + /* + * Private constructor to avoid Jacoco complaining about public constructor + * not covered: https://tinyurl.com/yetc7tra + */ + private ThrowingSupplierWrapper() {} + + /** + * Utility method to use a method throwing checked exception inside a place + * that does not allow throwing the corresponding checked exception (e.g., + * enum initialization). + * Convert the checked exception thrown by by throwingConsumer to a RuntimeException + * so that the compiler won't complain. + * @param the method's return type + * @param throwingSupplier the method reference that can throw checked exception + * @return converted method reference + */ + public static Supplier throwingSupplierWrapper(ThrowingSupplier throwingSupplier) { + + return () -> { + try { + return throwingSupplier.get(); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + }; + } +} diff --git a/src/main/java/org/opensearch/timeseries/indices/IndexManagement.java b/src/main/java/org/opensearch/timeseries/indices/IndexManagement.java index d342f3f85..d191be72d 100644 --- a/src/main/java/org/opensearch/timeseries/indices/IndexManagement.java +++ b/src/main/java/org/opensearch/timeseries/indices/IndexManagement.java @@ -58,13 +58,13 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Strings; -import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.InjectSecurity; +import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.core.xcontent.XContentParser.Token; diff --git a/src/main/java/org/opensearch/timeseries/indices/IndexManagement.java-e b/src/main/java/org/opensearch/timeseries/indices/IndexManagement.java-e new file mode 100644 index 000000000..b3c49a78f --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/indices/IndexManagement.java-e @@ -0,0 +1,1002 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.indices; + +import static org.opensearch.cluster.metadata.IndexMetadata.SETTING_REPLICATION_TYPE; +import static org.opensearch.indices.replication.common.ReplicationType.DOCUMENT; +import static org.opensearch.timeseries.constant.CommonMessages.CAN_NOT_FIND_RESULT_INDEX; + +import java.io.IOException; +import java.net.URL; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.EnumMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.cluster.state.ClusterStateRequest; +import org.opensearch.action.admin.indices.alias.Alias; +import org.opensearch.action.admin.indices.alias.get.GetAliasesRequest; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import org.opensearch.action.admin.indices.mapping.put.PutMappingRequest; +import org.opensearch.action.admin.indices.rollover.RolloverRequest; +import org.opensearch.action.admin.indices.settings.get.GetSettingsAction; +import org.opensearch.action.admin.indices.settings.get.GetSettingsRequest; +import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; +import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.GroupedActionListener; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.cluster.LocalNodeClusterManagerListener; +import org.opensearch.cluster.metadata.AliasMetadata; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Strings; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.InjectSecurity; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParser.Token; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.threadpool.Scheduler; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.constant.CommonValue; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +import 
com.google.common.base.Charsets; +import com.google.common.io.Resources; + +public abstract class IndexManagement & TimeSeriesIndex> implements LocalNodeClusterManagerListener { + private static final Logger logger = LogManager.getLogger(IndexManagement.class); + + // minimum shards of the job index + public static int minJobIndexReplicas = 1; + // maximum shards of the job index + public static int maxJobIndexReplicas = 20; + // package private for testing + public static final String META = "_meta"; + public static final String SCHEMA_VERSION = "schema_version"; + + protected ClusterService clusterService; + protected final Client client; + protected final AdminClient adminClient; + protected final ThreadPool threadPool; + protected DiscoveryNodeFilterer nodeFilter; + // index settings + protected final Settings settings; + // don't retry updating endlessly. Can be annoying if there are too many exception logs. + protected final int maxUpdateRunningTimes; + + // whether all index have the correct mappings + protected boolean allMappingUpdated; + // whether all index settings are updated + protected boolean allSettingUpdated; + // we only want one update at a time + protected final AtomicBoolean updateRunning; + // the number of times updates run + protected int updateRunningTimes; + private final Class indexType; + // keep track of whether the mapping version is up-to-date + protected EnumMap indexStates; + protected int maxPrimaryShards; + private Scheduler.Cancellable scheduledRollover = null; + protected volatile TimeValue historyRolloverPeriod; + protected volatile Long historyMaxDocs; + protected volatile TimeValue historyRetentionPeriod; + // result index mapping to valida custom index + private Map RESULT_FIELD_CONFIGS; + private String resultMapping; + + protected class IndexState { + // keep track of whether the mapping version is up-to-date + public Boolean mappingUpToDate; + // keep track of whether the setting needs to change + public Boolean settingUpToDate; + // record schema version reading from the mapping file + public Integer schemaVersion; + + public IndexState(String mappingFile) { + this.mappingUpToDate = false; + this.settingUpToDate = false; + this.schemaVersion = IndexManagement.parseSchemaVersion(mappingFile); + } + } + + protected IndexManagement( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + Settings settings, + DiscoveryNodeFilterer nodeFilter, + int maxUpdateRunningTimes, + Class indexType, + int maxPrimaryShards, + TimeValue historyRolloverPeriod, + Long historyMaxDocs, + TimeValue historyRetentionPeriod, + String resultMapping + ) + throws IOException { + this.client = client; + this.adminClient = client.admin(); + this.clusterService = clusterService; + this.threadPool = threadPool; + this.clusterService.addLocalNodeClusterManagerListener(this); + this.nodeFilter = nodeFilter; + this.settings = Settings.builder().put("index.hidden", true).build(); + this.maxUpdateRunningTimes = maxUpdateRunningTimes; + this.indexType = indexType; + this.maxPrimaryShards = maxPrimaryShards; + this.historyRolloverPeriod = historyRolloverPeriod; + this.historyMaxDocs = historyMaxDocs; + this.historyRetentionPeriod = historyRetentionPeriod; + + this.allMappingUpdated = false; + this.allSettingUpdated = false; + this.updateRunning = new AtomicBoolean(false); + this.updateRunningTimes = 0; + this.resultMapping = resultMapping; + } + + /** + * Alias exists or not + * @param alias Alias name + * @return true if the alias exists + */ + public boolean 
doesAliasExist(String alias) { + return clusterService.state().metadata().hasAlias(alias); + } + + public static Integer parseSchemaVersion(String mapping) { + try { + XContentParser xcp = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, mapping); + + while (!xcp.isClosed()) { + Token token = xcp.currentToken(); + if (token != null && token != XContentParser.Token.END_OBJECT && token != XContentParser.Token.START_OBJECT) { + if (xcp.currentName() != IndexManagement.META) { + xcp.nextToken(); + xcp.skipChildren(); + } else { + while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { + if (xcp.currentName().equals(IndexManagement.SCHEMA_VERSION)) { + + Integer version = xcp.intValue(); + if (version < 0) { + version = CommonValue.NO_SCHEMA_VERSION; + } + return version; + } else { + xcp.nextToken(); + } + } + + } + } + xcp.nextToken(); + } + return CommonValue.NO_SCHEMA_VERSION; + } catch (Exception e) { + // since this method is called in the constructor that is called by TimeSeriesAnalyticsPlugin.createComponents, + // we cannot throw checked exception + throw new RuntimeException(e); + } + } + + protected static Integer getIntegerSetting(GetSettingsResponse settingsResponse, String settingKey) { + Integer value = null; + for (Settings settings : settingsResponse.getIndexToSettings().values()) { + value = settings.getAsInt(settingKey, null); + if (value != null) { + break; + } + } + return value; + } + + protected static String getStringSetting(GetSettingsResponse settingsResponse, String settingKey) { + String value = null; + for (Settings settings : settingsResponse.getIndexToSettings().values()) { + value = settings.get(settingKey, null); + if (value != null) { + break; + } + } + return value; + } + + public boolean doesIndexExist(String indexName) { + return clusterService.state().metadata().hasIndex(indexName); + } + + protected static String getMappings(String mappingFileRelativePath) throws IOException { + URL url = IndexManagement.class.getClassLoader().getResource(mappingFileRelativePath); + return Resources.toString(url, Charsets.UTF_8); + } + + protected void choosePrimaryShards(CreateIndexRequest request, boolean hiddenIndex) { + request + .settings( + Settings + .builder() + // put 1 primary shards per hot node if possible + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, getNumberOfPrimaryShards()) + // 1 replica for better search performance and fail-over + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1) + .put("index.hidden", hiddenIndex) + ); + } + + protected void deleteOldHistoryIndices(String indexPattern, TimeValue historyRetentionPeriod) { + Set candidates = new HashSet(); + + ClusterStateRequest clusterStateRequest = new ClusterStateRequest() + .clear() + .indices(indexPattern) + .metadata(true) + .local(true) + .indicesOptions(IndicesOptions.strictExpand()); + + adminClient.cluster().state(clusterStateRequest, ActionListener.wrap(clusterStateResponse -> { + String latestToDelete = null; + long latest = Long.MIN_VALUE; + for (IndexMetadata indexMetaData : clusterStateResponse.getState().metadata().indices().values()) { + long creationTime = indexMetaData.getCreationDate(); + if ((Instant.now().toEpochMilli() - creationTime) > historyRetentionPeriod.millis()) { + String indexName = indexMetaData.getIndex().getName(); + candidates.add(indexName); + if (latest < creationTime) { + latest = creationTime; + latestToDelete = indexName; + } + } + } + if (candidates.size() > 1) { + // delete all indices except 
the last one because the last one may contain docs newer than the retention period + candidates.remove(latestToDelete); + String[] toDelete = candidates.toArray(Strings.EMPTY_ARRAY); + DeleteIndexRequest deleteIndexRequest = new DeleteIndexRequest(toDelete); + adminClient.indices().delete(deleteIndexRequest, ActionListener.wrap(deleteIndexResponse -> { + if (!deleteIndexResponse.isAcknowledged()) { + logger.error("Could not delete one or more result indices: {}. Retrying one by one.", Arrays.toString(toDelete)); + deleteIndexIteration(toDelete); + } else { + logger.info("Succeeded in deleting expired result indices: {}.", Arrays.toString(toDelete)); + } + }, exception -> { + logger.error("Failed to delete expired result indices: {}.", Arrays.toString(toDelete)); + deleteIndexIteration(toDelete); + })); + } + }, exception -> { logger.error("Fail to delete result indices", exception); })); + } + + protected void deleteIndexIteration(String[] toDelete) { + for (String index : toDelete) { + DeleteIndexRequest singleDeleteRequest = new DeleteIndexRequest(index); + adminClient.indices().delete(singleDeleteRequest, ActionListener.wrap(singleDeleteResponse -> { + if (!singleDeleteResponse.isAcknowledged()) { + logger.error("Retrying deleting {} does not succeed.", index); + } + }, exception -> { + if (exception instanceof IndexNotFoundException) { + logger.info("{} was already deleted.", index); + } else { + logger.error(new ParameterizedMessage("Retrying deleting {} does not succeed.", index), exception); + } + })); + } + } + + @SuppressWarnings("unchecked") + protected void shouldUpdateConcreteIndex(String concreteIndex, Integer newVersion, ActionListener thenDo) { + IndexMetadata indexMeataData = clusterService.state().getMetadata().indices().get(concreteIndex); + if (indexMeataData == null) { + thenDo.onResponse(Boolean.FALSE); + return; + } + Integer oldVersion = CommonValue.NO_SCHEMA_VERSION; + + Map indexMapping = indexMeataData.mapping().getSourceAsMap(); + Object meta = indexMapping.get(IndexManagement.META); + if (meta != null && meta instanceof Map) { + Map metaMapping = (Map) meta; + Object schemaVersion = metaMapping.get(org.opensearch.timeseries.constant.CommonName.SCHEMA_VERSION_FIELD); + if (schemaVersion instanceof Integer) { + oldVersion = (Integer) schemaVersion; + } + } + thenDo.onResponse(newVersion > oldVersion); + } + + protected void updateJobIndexSettingIfNecessary(String indexName, IndexState jobIndexState, ActionListener listener) { + GetSettingsRequest getSettingsRequest = new GetSettingsRequest() + .indices(indexName) + .names( + new String[] { + IndexMetadata.SETTING_NUMBER_OF_SHARDS, + IndexMetadata.SETTING_NUMBER_OF_REPLICAS, + IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS } + ); + client.execute(GetSettingsAction.INSTANCE, getSettingsRequest, ActionListener.wrap(settingResponse -> { + // auto expand setting is a range string like "1-all" + String autoExpandReplica = getStringSetting(settingResponse, IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS); + // if the auto expand setting is already there, return immediately + if (autoExpandReplica != null) { + jobIndexState.settingUpToDate = true; + logger.info(new ParameterizedMessage("Mark [{}]'s mapping up-to-date", indexName)); + listener.onResponse(null); + return; + } + Integer primaryShardsNumber = getIntegerSetting(settingResponse, IndexMetadata.SETTING_NUMBER_OF_SHARDS); + Integer replicaNumber = getIntegerSetting(settingResponse, IndexMetadata.SETTING_NUMBER_OF_REPLICAS); + if (primaryShardsNumber == null || 
replicaNumber == null) { + logger + .error( + new ParameterizedMessage( + "Fail to find job index's primary or replica shard number: primary [{}], replica [{}]", + primaryShardsNumber, + replicaNumber + ) + ); + // don't throw exception as we don't know how to handle it and retry next time + listener.onResponse(null); + return; + } + // at least minJobIndexReplicas + // at most maxJobIndexReplicas / primaryShardsNumber replicas. + // For example, if we have 2 primary shards, since the max number of shards are maxJobIndexReplicas (20), + // we will use 20 / 2 = 10 replicas as the upper bound of replica. + int maxExpectedReplicas = Math + .max(IndexManagement.maxJobIndexReplicas / primaryShardsNumber, IndexManagement.minJobIndexReplicas); + Settings updatedSettings = Settings + .builder() + .put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, IndexManagement.minJobIndexReplicas + "-" + maxExpectedReplicas) + .build(); + final UpdateSettingsRequest updateSettingsRequest = new UpdateSettingsRequest(indexName).settings(updatedSettings); + client.admin().indices().updateSettings(updateSettingsRequest, ActionListener.wrap(response -> { + jobIndexState.settingUpToDate = true; + logger.info(new ParameterizedMessage("Mark [{}]'s mapping up-to-date", indexName)); + listener.onResponse(null); + }, listener::onFailure)); + }, e -> { + if (e instanceof IndexNotFoundException) { + // new index will be created with auto expand replica setting + jobIndexState.settingUpToDate = true; + logger.info(new ParameterizedMessage("Mark [{}]'s mapping up-to-date", indexName)); + listener.onResponse(null); + } else { + listener.onFailure(e); + } + })); + } + + /** + * Create config index if not exist. + * + * @param actionListener action called after create index + * @throws IOException IOException from {@link IndexManagement#getConfigMappings} + */ + public void initConfigIndexIfAbsent(ActionListener actionListener) throws IOException { + if (!doesConfigIndexExist()) { + initConfigIndex(actionListener); + } + } + + /** + * Create config index directly. + * + * @param actionListener action called after create index + * @throws IOException IOException from {@link IndexManagement#getConfigMappings} + */ + public void initConfigIndex(ActionListener actionListener) throws IOException { + // time series indices need RAW (e.g., we want users to be able to consume AD results as soon as possible + // and send out an alert if anomalies found). + Settings replicationSettings = Settings.builder().put(SETTING_REPLICATION_TYPE, DOCUMENT.name()).build(); + CreateIndexRequest request = new CreateIndexRequest(CommonName.CONFIG_INDEX, replicationSettings) + .mapping(getConfigMappings(), XContentType.JSON) + .settings(settings); + adminClient.indices().create(request, actionListener); + } + + /** + * Config index exist or not. + * + * @return true if config index exists + */ + public boolean doesConfigIndexExist() { + return doesIndexExist(CommonName.CONFIG_INDEX); + } + + /** + * Job index exist or not. + * + * @return true if anomaly detector job index exists + */ + public boolean doesJobIndexExist() { + return doesIndexExist(CommonName.JOB_INDEX); + } + + /** + * Get config index mapping in json format. + * + * @return config index mapping + * @throws IOException IOException if mapping file can't be read correctly + */ + public static String getConfigMappings() throws IOException { + return getMappings(TimeSeriesSettings.CONFIG_INDEX_MAPPING_FILE); + } + + /** + * Get job index mapping in json format. 
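+ * The mapping is loaded from the JSON file named by TimeSeriesSettings.JOBS_INDEX_MAPPING_FILE on the classpath.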
+ * + * @return job index mapping + * @throws IOException IOException if mapping file can't be read correctly + */ + public static String getJobMappings() throws IOException { + return getMappings(TimeSeriesSettings.JOBS_INDEX_MAPPING_FILE); + } + + /** + * Createjob index. + * + * @param actionListener action called after create index + */ + public void initJobIndex(ActionListener actionListener) { + try { + // time series indices need RAW (e.g., we want users to be able to consume AD results as soon as + // possible and send out an alert if anomalies found). + Settings replicationSettings = Settings.builder().put(SETTING_REPLICATION_TYPE, DOCUMENT.name()).build(); + CreateIndexRequest request = new CreateIndexRequest(CommonName.JOB_INDEX, replicationSettings) + .mapping(getJobMappings(), XContentType.JSON); + request + .settings( + Settings + .builder() + // AD job index is small. 1 primary shard is enough + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + // Job scheduler puts both primary and replica shards in the + // hash ring. Auto-expand the number of replicas based on the + // number of data nodes (up to 20) in the cluster so that each node can + // become a coordinating node. This is useful when customers + // scale out their cluster so that we can do adaptive scaling + // accordingly. + // At least 1 replica for fail-over. + .put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, minJobIndexReplicas + "-" + maxJobIndexReplicas) + .put("index.hidden", true) + ); + adminClient.indices().create(request, actionListener); + } catch (IOException e) { + logger.error("Fail to init AD job index", e); + actionListener.onFailure(e); + } + } + + public void validateCustomResultIndexAndExecute(String resultIndex, ExecutorFunction function, ActionListener listener) { + try { + if (!isValidResultIndexMapping(resultIndex)) { + logger.warn("Can't create detector with custom result index {} as its mapping is invalid", resultIndex); + listener.onFailure(new IllegalArgumentException(CommonMessages.INVALID_RESULT_INDEX_MAPPING + resultIndex)); + return; + } + + IndexRequest indexRequest = createDummyIndexRequest(resultIndex); + + // User may have no write permission on custom result index. Talked with security plugin team, seems no easy way to verify + // if user has write permission. So just tried to write and delete a dummy forecast result to verify. 
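+ // The sequence below indexes a dummy document, then deletes it; only if both calls succeed is the supplied function executed, and any failure is forwarded to the listener.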
+ client.index(indexRequest, ActionListener.wrap(response -> { + logger.debug("Successfully wrote dummy result to result index {}", resultIndex); + client.delete(createDummyDeleteRequest(resultIndex), ActionListener.wrap(deleteResponse -> { + logger.debug("Successfully deleted dummy result from result index {}", resultIndex); + function.execute(); + }, ex -> { + logger.error("Failed to delete dummy result from result index " + resultIndex, ex); + listener.onFailure(ex); + })); + }, exception -> { + logger.error("Failed to write dummy result to result index " + resultIndex, exception); + listener.onFailure(exception); + })); + } catch (Exception e) { + logger.error("Failed to validate custom result index " + resultIndex, e); + listener.onFailure(e); + } + } + + public void update() { + if ((allMappingUpdated && allSettingUpdated) || updateRunningTimes >= maxUpdateRunningTimes || updateRunning.get()) { + return; + } + updateRunning.set(true); + updateRunningTimes++; + + // set updateRunning to false when both updateMappingIfNecessary and updateSettingIfNecessary + // stop running + final GroupedActionListener groupListeneer = new GroupedActionListener<>( + ActionListener.wrap(r -> updateRunning.set(false), exception -> { + updateRunning.set(false); + logger.error("Fail to update time series indices", exception); + }), + // 2 since we need both updateMappingIfNecessary and updateSettingIfNecessary to return + // before setting updateRunning to false + 2 + ); + + updateMappingIfNecessary(groupListeneer); + updateSettingIfNecessary(groupListeneer); + } + + private void updateSettingIfNecessary(GroupedActionListener delegateListeneer) { + if (allSettingUpdated) { + delegateListeneer.onResponse(null); + return; + } + + List updates = new ArrayList<>(); + for (IndexType index : indexType.getEnumConstants()) { + Boolean updated = indexStates.computeIfAbsent(index, k -> new IndexState(k.getMapping())).settingUpToDate; + if (Boolean.FALSE.equals(updated)) { + updates.add(index); + } + } + if (updates.size() == 0) { + allSettingUpdated = true; + delegateListeneer.onResponse(null); + return; + } + + final GroupedActionListener conglomerateListeneer = new GroupedActionListener<>( + ActionListener.wrap(r -> delegateListeneer.onResponse(null), exception -> { + delegateListeneer.onResponse(null); + logger.error("Fail to update time series indices' mappings", exception); + }), + updates.size() + ); + for (IndexType timeseriesIndex : updates) { + logger.info(new ParameterizedMessage("Check [{}]'s setting", timeseriesIndex.getIndexName())); + if (timeseriesIndex.isJobIndex()) { + updateJobIndexSettingIfNecessary( + ADIndex.JOB.getIndexName(), + indexStates.computeIfAbsent(timeseriesIndex, k -> new IndexState(k.getMapping())), + conglomerateListeneer + ); + } else { + // we don't have settings to update for other indices + IndexState indexState = indexStates.computeIfAbsent(timeseriesIndex, k -> new IndexState(k.getMapping())); + indexState.settingUpToDate = true; + logger.info(new ParameterizedMessage("Mark [{}]'s setting up-to-date", timeseriesIndex.getIndexName())); + conglomerateListeneer.onResponse(null); + } + } + } + + /** + * Update mapping if schema version changes. 
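+ * For each index whose mapping is not yet marked up-to-date, the schema version embedded in the bundled mapping file is compared with the version recorded in the index _meta mapping, and a put-mapping request is issued only when the bundled version is newer.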
+ */ + private void updateMappingIfNecessary(GroupedActionListener delegateListeneer) { + if (allMappingUpdated) { + delegateListeneer.onResponse(null); + return; + } + + List updates = new ArrayList<>(); + for (IndexType index : indexType.getEnumConstants()) { + Boolean updated = indexStates.computeIfAbsent(index, k -> new IndexState(k.getMapping())).mappingUpToDate; + if (Boolean.FALSE.equals(updated)) { + updates.add(index); + } + } + if (updates.size() == 0) { + allMappingUpdated = true; + delegateListeneer.onResponse(null); + return; + } + + final GroupedActionListener conglomerateListeneer = new GroupedActionListener<>( + ActionListener.wrap(r -> delegateListeneer.onResponse(null), exception -> { + delegateListeneer.onResponse(null); + logger.error("Fail to update time series indices' mappings", exception); + }), + updates.size() + ); + + for (IndexType adIndex : updates) { + logger.info(new ParameterizedMessage("Check [{}]'s mapping", adIndex.getIndexName())); + shouldUpdateIndex(adIndex, ActionListener.wrap(shouldUpdate -> { + if (shouldUpdate) { + adminClient + .indices() + .putMapping( + new PutMappingRequest().indices(adIndex.getIndexName()).source(adIndex.getMapping(), XContentType.JSON), + ActionListener.wrap(putMappingResponse -> { + if (putMappingResponse.isAcknowledged()) { + logger.info(new ParameterizedMessage("Succeeded in updating [{}]'s mapping", adIndex.getIndexName())); + markMappingUpdated(adIndex); + } else { + logger.error(new ParameterizedMessage("Fail to update [{}]'s mapping", adIndex.getIndexName())); + } + conglomerateListeneer.onResponse(null); + }, exception -> { + logger + .error( + new ParameterizedMessage( + "Fail to update [{}]'s mapping due to [{}]", + adIndex.getIndexName(), + exception.getMessage() + ) + ); + conglomerateListeneer.onFailure(exception); + }) + ); + } else { + // index does not exist or the version is already up-to-date. + // When creating index, new mappings will be used. + // We don't need to update it. 
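+ // Mark the state as up-to-date so the next update() pass skips this index.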
+ logger.info(new ParameterizedMessage("We don't need to update [{}]'s mapping", adIndex.getIndexName())); + markMappingUpdated(adIndex); + conglomerateListeneer.onResponse(null); + } + }, exception -> { + logger + .error( + new ParameterizedMessage("Fail to check whether we should update [{}]'s mapping", adIndex.getIndexName()), + exception + ); + conglomerateListeneer.onFailure(exception); + })); + + } + } + + private void markMappingUpdated(IndexType adIndex) { + IndexState indexState = indexStates.computeIfAbsent(adIndex, k -> new IndexState(k.getMapping())); + if (Boolean.FALSE.equals(indexState.mappingUpToDate)) { + indexState.mappingUpToDate = Boolean.TRUE; + logger.info(new ParameterizedMessage("Mark [{}]'s mapping up-to-date", adIndex.getIndexName())); + } + } + + private void shouldUpdateIndex(IndexType index, ActionListener thenDo) { + boolean exists = false; + if (index.isAlias()) { + exists = doesAliasExist(index.getIndexName()); + } else { + exists = doesIndexExist(index.getIndexName()); + } + if (false == exists) { + thenDo.onResponse(Boolean.FALSE); + return; + } + + Integer newVersion = indexStates.computeIfAbsent(index, k -> new IndexState(k.getMapping())).schemaVersion; + if (index.isAlias()) { + GetAliasesRequest getAliasRequest = new GetAliasesRequest() + .aliases(index.getIndexName()) + .indicesOptions(IndicesOptions.lenientExpandOpenHidden()); + adminClient.indices().getAliases(getAliasRequest, ActionListener.wrap(getAliasResponse -> { + String concreteIndex = null; + for (Map.Entry> entry : getAliasResponse.getAliases().entrySet()) { + if (false == entry.getValue().isEmpty()) { + // we assume the alias map to one concrete index, thus we can return after finding one + concreteIndex = entry.getKey(); + break; + } + } + if (concreteIndex == null) { + thenDo.onResponse(Boolean.FALSE); + return; + } + shouldUpdateConcreteIndex(concreteIndex, newVersion, thenDo); + }, exception -> logger.error(new ParameterizedMessage("Fail to get [{}]'s alias", index.getIndexName()), exception))); + } else { + shouldUpdateConcreteIndex(index.getIndexName(), newVersion, thenDo); + } + } + + /** + * + * @param index Index metadata + * @return The schema version of the given Index + */ + public int getSchemaVersion(IndexType index) { + IndexState indexState = this.indexStates.computeIfAbsent(index, k -> new IndexState(k.getMapping())); + return indexState.schemaVersion; + } + + public void initCustomResultIndexAndExecute(String resultIndex, ExecutorFunction function, ActionListener listener) { + if (!doesIndexExist(resultIndex)) { + initCustomResultIndexDirectly(resultIndex, ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + logger.info("Successfully created result index {}", resultIndex); + validateCustomResultIndexAndExecute(resultIndex, function, listener); + } else { + String error = "Creating result index with mappings call not acknowledged: " + resultIndex; + logger.error(error); + listener.onFailure(new EndRunException(error, false)); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we sending the create request + validateCustomResultIndexAndExecute(resultIndex, function, listener); + } else { + logger.error("Failed to create result index " + resultIndex, exception); + listener.onFailure(exception); + } + })); + } else { + validateCustomResultIndexAndExecute(resultIndex, function, listener); + } + } + + public void 
validateCustomIndexForBackendJob( + String resultIndex, + String securityLogId, + String user, + List roles, + ExecutorFunction function, + ActionListener listener + ) { + if (!doesIndexExist(resultIndex)) { + listener.onFailure(new EndRunException(CAN_NOT_FIND_RESULT_INDEX + resultIndex, true)); + return; + } + if (!isValidResultIndexMapping(resultIndex)) { + listener.onFailure(new EndRunException("Result index mapping is not correct", true)); + return; + } + try (InjectSecurity injectSecurity = new InjectSecurity(securityLogId, settings, client.threadPool().getThreadContext())) { + injectSecurity.inject(user, roles); + ActionListener wrappedListener = ActionListener.wrap(r -> { listener.onResponse(r); }, e -> { + injectSecurity.close(); + listener.onFailure(e); + }); + validateCustomResultIndexAndExecute(resultIndex, () -> { + injectSecurity.close(); + function.execute(); + }, wrappedListener); + } catch (Exception e) { + logger.error("Failed to validate custom index for backend job " + securityLogId, e); + listener.onFailure(e); + } + } + + protected int getNumberOfPrimaryShards() { + return Math.min(nodeFilter.getNumberOfEligibleDataNodes(), maxPrimaryShards); + } + + @Override + public void onClusterManager() { + try { + // try to rollover immediately as we might be restarting the cluster + rolloverAndDeleteHistoryIndex(); + + // schedule the next rollover for approx MAX_AGE later + scheduledRollover = threadPool + .scheduleWithFixedDelay(() -> rolloverAndDeleteHistoryIndex(), historyRolloverPeriod, executorName()); + } catch (Exception e) { + // This should be run on cluster startup + logger.error("Error rollover result indices. " + "Can't rollover result until clusterManager node is restarted.", e); + } + } + + @Override + public void offClusterManager() { + if (scheduledRollover != null) { + scheduledRollover.cancel(); + } + } + + private String executorName() { + return ThreadPool.Names.MANAGEMENT; + } + + protected void rescheduleRollover() { + if (clusterService.state().getNodes().isLocalNodeElectedClusterManager()) { + if (scheduledRollover != null) { + scheduledRollover.cancel(); + } + scheduledRollover = threadPool + .scheduleWithFixedDelay(() -> rolloverAndDeleteHistoryIndex(), historyRolloverPeriod, executorName()); + } + } + + private void initResultMapping() throws IOException { + if (RESULT_FIELD_CONFIGS != null) { + // we have already initiated the field + return; + } + + Map asMap = XContentHelper.convertToMap(new BytesArray(resultMapping), false, XContentType.JSON).v2(); + Object properties = asMap.get(CommonName.PROPERTIES); + if (properties instanceof Map) { + RESULT_FIELD_CONFIGS = (Map) properties; + } else { + logger.error("Fail to read result mapping file."); + } + } + + /** + * Check if custom result index has correct index mapping. 
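+ * The mapping is considered valid only when every field defined in the bundled result mapping is present in the custom index with an identical definition.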
+ * @param resultIndex result index + * @return true if result index mapping is valid + */ + public boolean isValidResultIndexMapping(String resultIndex) { + try { + initResultMapping(); + if (RESULT_FIELD_CONFIGS == null) { + // failed to populate the field + return false; + } + IndexMetadata indexMetadata = clusterService.state().metadata().index(resultIndex); + Map indexMapping = indexMetadata.mapping().sourceAsMap(); + String propertyName = CommonName.PROPERTIES; + if (!indexMapping.containsKey(propertyName) || !(indexMapping.get(propertyName) instanceof LinkedHashMap)) { + return false; + } + LinkedHashMap mapping = (LinkedHashMap) indexMapping.get(propertyName); + + boolean correctResultIndexMapping = true; + + for (String fieldName : RESULT_FIELD_CONFIGS.keySet()) { + Object defaultSchema = RESULT_FIELD_CONFIGS.get(fieldName); + // the field might be a map or map of map + // example: map: {type=date, format=strict_date_time||epoch_millis} + // map of map: {type=nested, properties={likelihood={type=double}, value_list={type=nested, properties={data={type=double}, + // feature_id={type=keyword}}}}} + // if it is a map of map, Object.equals can compare them regardless of order + if (!mapping.containsKey(fieldName) || !defaultSchema.equals(mapping.get(fieldName))) { + correctResultIndexMapping = false; + break; + } + } + return correctResultIndexMapping; + } catch (Exception e) { + logger.error("Failed to validate result index mapping for index " + resultIndex, e); + return false; + } + + } + + /** + * Create forecast result index if not exist. + * + * @param actionListener action called after create index + */ + public void initDefaultResultIndexIfAbsent(ActionListener actionListener) { + if (!doesDefaultResultIndexExist()) { + initDefaultResultIndexDirectly(actionListener); + } + } + + protected ActionListener markMappingUpToDate( + IndexType index, + ActionListener followingListener + ) { + return ActionListener.wrap(createdResponse -> { + if (createdResponse.isAcknowledged()) { + IndexState indexStatetate = indexStates.computeIfAbsent(index, k -> new IndexState(k.getMapping())); + if (Boolean.FALSE.equals(indexStatetate.mappingUpToDate)) { + indexStatetate.mappingUpToDate = Boolean.TRUE; + logger.info(new ParameterizedMessage("Mark [{}]'s mapping up-to-date", index.getIndexName())); + } + } + followingListener.onResponse(createdResponse); + }, exception -> followingListener.onFailure(exception)); + } + + protected void rolloverAndDeleteHistoryIndex( + String resultIndexAlias, + String allResultIndicesPattern, + String rolloverIndexPattern, + IndexType resultIndex + ) { + if (!doesDefaultResultIndexExist()) { + return; + } + + // We have to pass null for newIndexName in order to get Elastic to increment the index count. + RolloverRequest rollOverRequest = new RolloverRequest(resultIndexAlias, null); + + CreateIndexRequest createRequest = rollOverRequest.getCreateIndexRequest(); + + // time series indices need RAW (e.g., we want users to be able to consume AD results as soon as possible + // and send out an alert if anomalies found). 
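+ // On a successful rollover the alias points at the newly created write index, and result indices older than the retention period are removed via deleteOldHistoryIndices.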
+ Settings replicationSettings = Settings.builder().put(SETTING_REPLICATION_TYPE, DOCUMENT.name()).build(); + createRequest.index(rolloverIndexPattern).settings(replicationSettings).mapping(resultMapping, XContentType.JSON); + + choosePrimaryShards(createRequest, true); + + rollOverRequest.addMaxIndexDocsCondition(historyMaxDocs * getNumberOfPrimaryShards()); + adminClient.indices().rolloverIndex(rollOverRequest, ActionListener.wrap(response -> { + if (!response.isRolledOver()) { + logger.warn("{} not rolled over. Conditions were: {}", resultIndexAlias, response.getConditionStatus()); + } else { + IndexState indexStatetate = indexStates.computeIfAbsent(resultIndex, k -> new IndexState(k.getMapping())); + indexStatetate.mappingUpToDate = true; + logger.info("{} rolled over. Conditions were: {}", resultIndexAlias, response.getConditionStatus()); + deleteOldHistoryIndices(allResultIndicesPattern, historyRetentionPeriod); + } + }, exception -> { logger.error("Fail to roll over result index", exception); })); + } + + protected void initResultIndexDirectly( + String resultIndexName, + String alias, + boolean hiddenIndex, + String resultIndexPattern, + IndexType resultIndex, + ActionListener actionListener + ) { + // time series indices need RAW (e.g., we want users to be able to consume AD results as soon as possible + // and send out an alert if anomalies found). + Settings replicationSettings = Settings.builder().put(SETTING_REPLICATION_TYPE, DOCUMENT.name()).build(); + CreateIndexRequest request = new CreateIndexRequest(resultIndexName, replicationSettings).mapping(resultMapping, XContentType.JSON); + if (alias != null) { + request.alias(new Alias(alias)); + } + choosePrimaryShards(request, hiddenIndex); + if (resultIndexPattern.equals(resultIndexName)) { + adminClient.indices().create(request, markMappingUpToDate(resultIndex, actionListener)); + } else { + adminClient.indices().create(request, actionListener); + } + } + + public abstract boolean doesCheckpointIndexExist(); + + public abstract void initCheckpointIndex(ActionListener actionListener); + + public abstract boolean doesDefaultResultIndexExist(); + + public abstract boolean doesStateIndexExist(); + + public abstract void initDefaultResultIndexDirectly(ActionListener actionListener); + + protected abstract IndexRequest createDummyIndexRequest(String resultIndex) throws IOException; + + protected abstract DeleteRequest createDummyDeleteRequest(String resultIndex) throws IOException; + + protected abstract void rolloverAndDeleteHistoryIndex(); + + public abstract void initCustomResultIndexDirectly(String resultIndex, ActionListener actionListener); + + public abstract void initStateIndex(ActionListener actionListener); +} diff --git a/src/main/java/org/opensearch/timeseries/indices/TimeSeriesIndex.java-e b/src/main/java/org/opensearch/timeseries/indices/TimeSeriesIndex.java-e new file mode 100644 index 000000000..e7364ed32 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/indices/TimeSeriesIndex.java-e @@ -0,0 +1,22 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.indices; + +public interface TimeSeriesIndex { + public String getIndexName(); + + public boolean isAlias(); + + public String getMapping(); + + public boolean isJobIndex(); +} diff --git a/src/main/java/org/opensearch/timeseries/ml/IntermediateResult.java-e b/src/main/java/org/opensearch/timeseries/ml/IntermediateResult.java-e new file mode 100644 index 000000000..9a8704842 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/IntermediateResult.java-e @@ -0,0 +1,86 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ml; + +import java.time.Instant; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.model.IndexableResult; + +public abstract class IntermediateResult { + protected final long totalUpdates; + protected final double rcfScore; + + public IntermediateResult(long totalUpdates, double rcfScore) { + this.totalUpdates = totalUpdates; + this.rcfScore = rcfScore; + } + + public long getTotalUpdates() { + return totalUpdates; + } + + public double getRcfScore() { + return rcfScore; + } + + @Override + public int hashCode() { + return Objects.hash(totalUpdates); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + IntermediateResult other = (IntermediateResult) obj; + return totalUpdates == other.totalUpdates && Double.doubleToLongBits(rcfScore) == Double.doubleToLongBits(other.rcfScore); + } + + /** + * convert intermediateResult into 1+ indexable results. 
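+ * Subclasses decide how a single in-memory model output maps onto the result documents that get persisted, which is why a list is returned.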
+ * @param config Config accessor + * @param dataStartInstant data start time + * @param dataEndInstant data end time + * @param executionStartInstant execution start time + * @param executionEndInstant execution end time + * @param featureData feature data + * @param entity entity info + * @param schemaVersion schema version + * @param modelId Model id + * @param taskId Task id + * @param error Error + * @return 1+ indexable results + */ + public abstract List toIndexableResults( + Config config, + Instant dataStartInstant, + Instant dataEndInstant, + Instant executionStartInstant, + Instant executionEndInstant, + List featureData, + Optional entity, + Integer schemaVersion, + String modelId, + String taskId, + String error + ); +} diff --git a/src/main/java/org/opensearch/timeseries/model/Config.java b/src/main/java/org/opensearch/timeseries/model/Config.java index 3ac400c9f..15f67d116 100644 --- a/src/main/java/org/opensearch/timeseries/model/Config.java +++ b/src/main/java/org/opensearch/timeseries/model/Config.java @@ -18,10 +18,10 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.util.Strings; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; diff --git a/src/main/java/org/opensearch/timeseries/model/Config.java-e b/src/main/java/org/opensearch/timeseries/model/Config.java-e new file mode 100644 index 000000000..461550b15 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/Config.java-e @@ -0,0 +1,575 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.model; + +import static org.opensearch.timeseries.constant.CommonMessages.INVALID_CHAR_IN_RESULT_INDEX_NAME; + +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.util.Strings; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.dataprocessor.FixedValueImputer; +import org.opensearch.timeseries.dataprocessor.ImputationMethod; +import 
org.opensearch.timeseries.dataprocessor.ImputationOption; +import org.opensearch.timeseries.dataprocessor.Imputer; +import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.dataprocessor.PreviousValueImputer; +import org.opensearch.timeseries.dataprocessor.ZeroImputer; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.google.common.base.Objects; +import com.google.common.collect.ImmutableList; + +public abstract class Config implements Writeable, ToXContentObject { + private static final Logger logger = LogManager.getLogger(Config.class); + + public static final int MAX_RESULT_INDEX_NAME_SIZE = 255; + // OS doesn’t allow uppercase: https://tinyurl.com/yse2xdbx + public static final String RESULT_INDEX_NAME_PATTERN = "[a-z0-9_-]+"; + + public static final String NO_ID = ""; + public static final String TIMEOUT = "timeout"; + public static final String GENERAL_SETTINGS = "general_settings"; + public static final String AGGREGATION = "aggregation_issue"; + + // field in JSON representation + public static final String NAME_FIELD = "name"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String TIMEFIELD_FIELD = "time_field"; + public static final String INDICES_FIELD = "indices"; + public static final String UI_METADATA_FIELD = "ui_metadata"; + public static final String FILTER_QUERY_FIELD = "filter_query"; + public static final String FEATURE_ATTRIBUTES_FIELD = "feature_attributes"; + public static final String WINDOW_DELAY_FIELD = "window_delay"; + public static final String SHINGLE_SIZE_FIELD = "shingle_size"; + public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; + public static final String CATEGORY_FIELD = "category_field"; + public static final String USER_FIELD = "user"; + public static final String RESULT_INDEX_FIELD = "result_index"; + public static final String IMPUTATION_OPTION_FIELD = "imputation_option"; + + private static final Imputer zeroImputer; + private static final Imputer previousImputer; + private static final Imputer linearImputer; + private static final Imputer linearImputerIntegerSensitive; + + protected String id; + protected Long version; + protected String name; + protected String description; + protected String timeField; + protected List indices; + protected List featureAttributes; + protected QueryBuilder filterQuery; + protected TimeConfiguration interval; + protected TimeConfiguration windowDelay; + protected Integer shingleSize; + protected String customResultIndex; + protected Map uiMetadata; + protected Integer schemaVersion; + protected Instant lastUpdateTime; + protected List categoryFields; + protected User user; + protected ImputationOption imputationOption; + + // validation error + protected String errorMessage; + protected ValidationIssueType issueType; + + protected Imputer imputer; + + public static String INVALID_RESULT_INDEX_NAME_SIZE = "Result index name size must contains less than " + + MAX_RESULT_INDEX_NAME_SIZE + + " characters"; + + static { + zeroImputer = new ZeroImputer(); + previousImputer = new PreviousValueImputer(); + linearImputer = new LinearUniformImputer(false); + linearImputerIntegerSensitive = new LinearUniformImputer(true); + } + + protected Config( + String id, + Long version, + String name, + String description, + String timeField, + List indices, + List features, + QueryBuilder filterQuery, + TimeConfiguration windowDelay, + Integer shingleSize, + Map uiMetadata, + Integer schemaVersion, + 
Instant lastUpdateTime, + List categoryFields, + User user, + String resultIndex, + TimeConfiguration interval, + ImputationOption imputationOption + ) { + if (Strings.isBlank(name)) { + errorMessage = CommonMessages.EMPTY_NAME; + issueType = ValidationIssueType.NAME; + return; + } + if (Strings.isBlank(timeField)) { + errorMessage = CommonMessages.NULL_TIME_FIELD; + issueType = ValidationIssueType.TIMEFIELD_FIELD; + return; + } + if (indices == null || indices.isEmpty()) { + errorMessage = CommonMessages.EMPTY_INDICES; + issueType = ValidationIssueType.INDICES; + return; + } + + if (invalidShingleSizeRange(shingleSize)) { + errorMessage = "Shingle size must be a positive integer no larger than " + + TimeSeriesSettings.MAX_SHINGLE_SIZE + + ". Got " + + shingleSize; + issueType = ValidationIssueType.SHINGLE_SIZE_FIELD; + return; + } + + errorMessage = validateCustomResultIndex(resultIndex); + if (errorMessage != null) { + issueType = ValidationIssueType.RESULT_INDEX; + return; + } + + if (imputationOption != null + && imputationOption.getMethod() == ImputationMethod.FIXED_VALUES + && imputationOption.getDefaultFill().isEmpty()) { + issueType = ValidationIssueType.IMPUTATION; + errorMessage = "No given values for fixed value interpolation"; + return; + } + + this.id = id; + this.version = version; + this.name = name; + this.description = description; + this.timeField = timeField; + this.indices = indices; + this.featureAttributes = features == null ? ImmutableList.of() : ImmutableList.copyOf(features); + this.filterQuery = filterQuery; + this.interval = interval; + this.windowDelay = windowDelay; + this.shingleSize = getShingleSize(shingleSize); + this.uiMetadata = uiMetadata; + this.schemaVersion = schemaVersion; + this.lastUpdateTime = lastUpdateTime; + this.categoryFields = categoryFields; + this.user = user; + this.customResultIndex = Strings.trimToNull(resultIndex); + this.imputationOption = imputationOption; + this.imputer = createImputer(); + this.issueType = null; + this.errorMessage = null; + } + + public Config(StreamInput input) throws IOException { + id = input.readOptionalString(); + version = input.readOptionalLong(); + name = input.readString(); + description = input.readOptionalString(); + timeField = input.readString(); + indices = input.readStringList(); + featureAttributes = input.readList(Feature::new); + filterQuery = input.readNamedWriteable(QueryBuilder.class); + interval = IntervalTimeConfiguration.readFrom(input); + windowDelay = IntervalTimeConfiguration.readFrom(input); + shingleSize = input.readInt(); + schemaVersion = input.readInt(); + this.categoryFields = input.readOptionalStringList(); + lastUpdateTime = input.readInstant(); + if (input.readBoolean()) { + this.user = new User(input); + } else { + user = null; + } + if (input.readBoolean()) { + this.uiMetadata = input.readMap(); + } else { + this.uiMetadata = null; + } + customResultIndex = input.readOptionalString(); + if (input.readBoolean()) { + this.imputationOption = new ImputationOption(input); + } else { + this.imputationOption = null; + } + this.imputer = createImputer(); + } + + /* + * Implicit constructor that be called implicitly when a subtype + * needs to call like AnomalyDetector(StreamInput). Otherwise, + * we will have compiler error: + * "Implicit super constructor Config() is undefined. + * Must explicitly invoke another constructor". 
+ */ + public Config() { + this.imputer = null; + } + + @Override + public void writeTo(StreamOutput output) throws IOException { + output.writeOptionalString(id); + output.writeOptionalLong(version); + output.writeString(name); + output.writeOptionalString(description); + output.writeString(timeField); + output.writeStringCollection(indices); + output.writeList(featureAttributes); + output.writeNamedWriteable(filterQuery); + interval.writeTo(output); + windowDelay.writeTo(output); + output.writeInt(shingleSize); + output.writeInt(schemaVersion); + output.writeOptionalStringCollection(categoryFields); + output.writeInstant(lastUpdateTime); + if (user != null) { + output.writeBoolean(true); // user exists + user.writeTo(output); + } else { + output.writeBoolean(false); // user does not exist + } + if (uiMetadata != null) { + output.writeBoolean(true); + output.writeMap(uiMetadata); + } else { + output.writeBoolean(false); + } + output.writeOptionalString(customResultIndex); + if (imputationOption != null) { + output.writeBoolean(true); + imputationOption.writeTo(output); + } else { + output.writeBoolean(false); + } + } + + /** + * If the given shingle size is null, return default; + * otherwise, return the given shingle size. + * + * @param customShingleSize Given shingle size + * @return Shingle size + */ + protected static Integer getShingleSize(Integer customShingleSize) { + return customShingleSize == null ? TimeSeriesSettings.DEFAULT_SHINGLE_SIZE : customShingleSize; + } + + public boolean invalidShingleSizeRange(Integer shingleSizeToTest) { + return shingleSizeToTest != null && (shingleSizeToTest < 1 || shingleSizeToTest > TimeSeriesSettings.MAX_SHINGLE_SIZE); + } + + /** + * + * @return either ValidationAspect.FORECASTER or ValidationAspect.DETECTOR + * depending on this is a forecaster or detector config. + */ + protected abstract ValidationAspect getConfigValidationAspect(); + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + Config config = (Config) o; + // a few fields not included: + // 1)didn't include uiMetadata since toXContent/parse will produce a map of map + // and cause the parsed one not equal to the original one. This can be confusing. + // 2)didn't include id, schemaVersion, and lastUpdateTime as we deemed equality based on contents. + // Including id fails tests like AnomalyDetectorExecutionInput.testParseAnomalyDetectorExecutionInput. 
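+        // Note (added clarification): Guava's Objects.equal is null-safe, so optional fields
+        // such as user or customResultIndex may be null on either side without throwing.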
+ return Objects.equal(name, config.name) + && Objects.equal(description, config.description) + && Objects.equal(timeField, config.timeField) + && Objects.equal(indices, config.indices) + && Objects.equal(featureAttributes, config.featureAttributes) + && Objects.equal(filterQuery, config.filterQuery) + && Objects.equal(interval, config.interval) + && Objects.equal(windowDelay, config.windowDelay) + && Objects.equal(shingleSize, config.shingleSize) + && Objects.equal(categoryFields, config.categoryFields) + && Objects.equal(user, config.user) + && Objects.equal(customResultIndex, config.customResultIndex) + && Objects.equal(imputationOption, config.imputationOption); + } + + @Generated + @Override + public int hashCode() { + return Objects + .hashCode( + name, + description, + timeField, + indices, + featureAttributes, + filterQuery, + interval, + windowDelay, + shingleSize, + categoryFields, + schemaVersion, + user, + customResultIndex, + imputationOption + ); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder + .field(NAME_FIELD, name) + .field(DESCRIPTION_FIELD, description) + .field(TIMEFIELD_FIELD, timeField) + .field(INDICES_FIELD, indices.toArray()) + .field(FILTER_QUERY_FIELD, filterQuery) + .field(WINDOW_DELAY_FIELD, windowDelay) + .field(SHINGLE_SIZE_FIELD, shingleSize) + .field(CommonName.SCHEMA_VERSION_FIELD, schemaVersion) + .field(FEATURE_ATTRIBUTES_FIELD, featureAttributes.toArray()); + + if (uiMetadata != null && !uiMetadata.isEmpty()) { + builder.field(UI_METADATA_FIELD, uiMetadata); + } + if (lastUpdateTime != null) { + builder.field(LAST_UPDATE_TIME_FIELD, lastUpdateTime.toEpochMilli()); + } + if (categoryFields != null) { + builder.field(CATEGORY_FIELD, categoryFields.toArray()); + } + if (user != null) { + builder.field(USER_FIELD, user); + } + if (customResultIndex != null) { + builder.field(RESULT_INDEX_FIELD, customResultIndex); + } + if (imputationOption != null) { + builder.field(IMPUTATION_OPTION_FIELD, imputationOption); + } + return builder; + } + + public Long getVersion() { + return version; + } + + public String getName() { + return name; + } + + public String getDescription() { + return description; + } + + public String getTimeField() { + return timeField; + } + + public List getIndices() { + return indices; + } + + public List getFeatureAttributes() { + return featureAttributes; + } + + public QueryBuilder getFilterQuery() { + return filterQuery; + } + + /** + * Returns enabled feature ids in the same order in feature attributes. + * + * @return a list of filtered feature ids. 
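+     *         Disabled features are skipped; e.g. (illustrative) with feature "f1" enabled and
+     *         "f2" disabled, only f1's id appears in the result.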
+ */ + public List getEnabledFeatureIds() { + return featureAttributes.stream().filter(Feature::getEnabled).map(Feature::getId).collect(Collectors.toList()); + } + + public List getEnabledFeatureNames() { + return featureAttributes.stream().filter(Feature::getEnabled).map(Feature::getName).collect(Collectors.toList()); + } + + public TimeConfiguration getInterval() { + return interval; + } + + public TimeConfiguration getWindowDelay() { + return windowDelay; + } + + public Integer getShingleSize() { + return shingleSize; + } + + public Map getUiMetadata() { + return uiMetadata; + } + + public Integer getSchemaVersion() { + return schemaVersion; + } + + public Instant getLastUpdateTime() { + return lastUpdateTime; + } + + public List getCategoryFields() { + return this.categoryFields; + } + + public String getId() { + return id; + } + + public long getIntervalInMilliseconds() { + return ((IntervalTimeConfiguration) getInterval()).toDuration().toMillis(); + } + + public long getIntervalInSeconds() { + return getIntervalInMilliseconds() / 1000; + } + + public long getIntervalInMinutes() { + return getIntervalInMilliseconds() / 1000 / 60; + } + + public Duration getIntervalDuration() { + return ((IntervalTimeConfiguration) getInterval()).toDuration(); + } + + public User getUser() { + return user; + } + + public void setUser(User user) { + this.user = user; + } + + public String getCustomResultIndex() { + return customResultIndex; + } + + public boolean isHighCardinality() { + return Config.isHC(getCategoryFields()); + } + + public boolean hasMultipleCategories() { + return categoryFields != null && categoryFields.size() > 1; + } + + public String validateCustomResultIndex(String resultIndex) { + if (resultIndex == null) { + return null; + } + if (resultIndex.length() > MAX_RESULT_INDEX_NAME_SIZE) { + return Config.INVALID_RESULT_INDEX_NAME_SIZE; + } + if (!resultIndex.matches(RESULT_INDEX_NAME_PATTERN)) { + return INVALID_CHAR_IN_RESULT_INDEX_NAME; + } + return null; + } + + public static boolean isHC(List categoryFields) { + return categoryFields != null && categoryFields.size() > 0; + } + + public ImputationOption getImputationOption() { + return imputationOption; + } + + public Imputer getImputer() { + if (imputer != null) { + return imputer; + } + imputer = createImputer(); + return imputer; + } + + protected Imputer createImputer() { + Imputer imputer = null; + + // default interpolator is using last known value + if (imputationOption == null) { + return previousImputer; + } + + switch (imputationOption.getMethod()) { + case ZERO: + imputer = zeroImputer; + break; + case FIXED_VALUES: + // we did validate default fill is not empty in the constructor + imputer = new FixedValueImputer(imputationOption.getDefaultFill().get()); + break; + case PREVIOUS: + imputer = previousImputer; + break; + case LINEAR: + if (imputationOption.isIntegerSentive()) { + imputer = linearImputerIntegerSensitive; + } else { + imputer = linearImputer; + } + break; + default: + logger.error("unsupported method: " + imputationOption.getMethod()); + imputer = new PreviousValueImputer(); + break; + } + return imputer; + } + + protected void checkAndThrowValidationErrors(ValidationAspect validationAspect) { + if (errorMessage != null && issueType != null) { + throw new ValidationException(errorMessage, issueType, validationAspect); + } else if (errorMessage != null || issueType != null) { + throw new TimeSeriesException(CommonMessages.FAIL_TO_VALIDATE); + } + } + + public static Config parseConfig(Class configClass, 
XContentParser parser) throws IOException { + if (configClass == AnomalyDetector.class) { + return AnomalyDetector.parse(parser); + } else if (configClass == Forecaster.class) { + return Forecaster.parse(parser); + } else { + throw new IllegalArgumentException("Unsupported config type. Supported config types are [AnomalyDetector, Forecaster]"); + } + } +} diff --git a/src/main/java/org/opensearch/timeseries/model/DataByFeatureId.java b/src/main/java/org/opensearch/timeseries/model/DataByFeatureId.java index fb44ffbb4..c74679214 100644 --- a/src/main/java/org/opensearch/timeseries/model/DataByFeatureId.java +++ b/src/main/java/org/opensearch/timeseries/model/DataByFeatureId.java @@ -5,13 +5,13 @@ package org.opensearch.timeseries.model; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; diff --git a/src/main/java/org/opensearch/timeseries/model/DataByFeatureId.java-e b/src/main/java/org/opensearch/timeseries/model/DataByFeatureId.java-e new file mode 100644 index 000000000..c74679214 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/DataByFeatureId.java-e @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import com.google.common.base.Objects; + +/** + * Data and its Id + * + */ +public class DataByFeatureId implements ToXContentObject, Writeable { + + public static final String FEATURE_ID_FIELD = "feature_id"; + public static final String DATA_FIELD = "data"; + + protected String featureId; + protected Double data; + + public DataByFeatureId(String featureId, Double data) { + this.featureId = featureId; + this.data = data; + } + + /* + * Used by the subclass that has its own way of initializing data like + * reading from StreamInput + */ + protected DataByFeatureId() {} + + public DataByFeatureId(StreamInput input) throws IOException { + this.featureId = input.readString(); + this.data = input.readDouble(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject().field(FEATURE_ID_FIELD, featureId).field(DATA_FIELD, data); + return xContentBuilder.endObject(); + } + + public static DataByFeatureId parse(XContentParser parser) throws IOException { + String featureId = null; + Double data = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + 
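+        // Illustrative input (hypothetical values): {"feature_id": "f1", "data": 0.25}.
+        // Walk the object field by field until END_OBJECT, skipping children of unknown fields.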
while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case FEATURE_ID_FIELD: + featureId = parser.text(); + break; + case DATA_FIELD: + data = parser.doubleValue(); + break; + default: + // the unknown field and it's children should be ignored + parser.skipChildren(); + break; + } + } + return new DataByFeatureId(featureId, data); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + DataByFeatureId that = (DataByFeatureId) o; + return Objects.equal(getFeatureId(), that.getFeatureId()) && Objects.equal(getData(), that.getData()); + } + + @Override + public int hashCode() { + return Objects.hashCode(getFeatureId(), getData()); + } + + public String getFeatureId() { + return featureId; + } + + public Double getData() { + return data; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(featureId); + out.writeDouble(data); + } + +} diff --git a/src/main/java/org/opensearch/timeseries/model/DateRange.java b/src/main/java/org/opensearch/timeseries/model/DateRange.java index cd376c7a6..f6b99b8e5 100644 --- a/src/main/java/org/opensearch/timeseries/model/DateRange.java +++ b/src/main/java/org/opensearch/timeseries/model/DateRange.java @@ -11,15 +11,15 @@ package org.opensearch.timeseries.model; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.time.Instant; import org.apache.commons.lang.builder.ToStringBuilder; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; diff --git a/src/main/java/org/opensearch/timeseries/model/DateRange.java-e b/src/main/java/org/opensearch/timeseries/model/DateRange.java-e new file mode 100644 index 000000000..f6b99b8e5 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/DateRange.java-e @@ -0,0 +1,131 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.time.Instant; + +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.base.Objects; + +public class DateRange implements ToXContentObject, Writeable { + + public static final String START_TIME_FIELD = "start_time"; + public static final String END_TIME_FIELD = "end_time"; + + private final Instant startTime; + private final Instant endTime; + + public DateRange(Instant startTime, Instant endTime) { + this.startTime = startTime; + this.endTime = endTime; + validate(); + } + + public DateRange(StreamInput in) throws IOException { + this.startTime = in.readInstant(); + this.endTime = in.readInstant(); + validate(); + } + + private void validate() { + if (startTime == null) { + throw new IllegalArgumentException("Detection data range's start time must not be null"); + } + if (endTime == null) { + throw new IllegalArgumentException("Detection data range's end time must not be null"); + } + if (startTime.isAfter(endTime)) { + throw new IllegalArgumentException("Detection data range's end time must be after start time"); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + xContentBuilder.field(START_TIME_FIELD, startTime.toEpochMilli()); + xContentBuilder.field(END_TIME_FIELD, endTime.toEpochMilli()); + return xContentBuilder.endObject(); + } + + public static DateRange parse(XContentParser parser) throws IOException { + Instant startTime = null; + Instant endTime = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case START_TIME_FIELD: + startTime = ParseUtils.toInstant(parser); + break; + case END_TIME_FIELD: + endTime = ParseUtils.toInstant(parser); + break; + default: + parser.skipChildren(); + break; + } + } + return new DateRange(startTime, endTime); + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + DateRange that = (DateRange) o; + return Objects.equal(getStartTime(), that.getStartTime()) && Objects.equal(getEndTime(), that.getEndTime()); + } + + @Generated + @Override + public int hashCode() { + return Objects.hashCode(getStartTime(), getEndTime()); + } + + @Generated + @Override + public String toString() { + return new ToStringBuilder(this).append("startTime", startTime).append("endTime", endTime).toString(); + } + + public Instant getStartTime() { + return startTime; + } + + public Instant getEndTime() { + return endTime; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeInstant(startTime); + out.writeInstant(endTime); + } +} diff --git 
a/src/main/java/org/opensearch/timeseries/model/Entity.java b/src/main/java/org/opensearch/timeseries/model/Entity.java index c1cf962dc..f05f5dc2a 100644 --- a/src/main/java/org/opensearch/timeseries/model/Entity.java +++ b/src/main/java/org/opensearch/timeseries/model/Entity.java @@ -11,7 +11,7 @@ package org.opensearch.timeseries.model; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.io.InputStream; @@ -26,13 +26,13 @@ import org.apache.lucene.util.SetOnce; import org.opensearch.common.Numbers; -import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.hash.MurmurHash3; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/main/java/org/opensearch/timeseries/model/Entity.java-e b/src/main/java/org/opensearch/timeseries/model/Entity.java-e new file mode 100644 index 000000000..d9c3c9c8b --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/Entity.java-e @@ -0,0 +1,400 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.SortedMap; +import java.util.TreeMap; + +import org.apache.lucene.util.SetOnce; +import org.opensearch.common.Numbers; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.common.hash.MurmurHash3; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParser.Token; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.constant.CommonName; + +import com.google.common.base.Joiner; +import com.google.common.base.Objects; + +/** + * Categorical field name and its value + * + */ +public class Entity implements ToXContentObject, Writeable { + + private static final long RANDOM_SEED = 42; + private static final String MODEL_ID_INFIX = "_entity_"; + + public static final String ATTRIBUTE_NAME_FIELD = "name"; + public static final String ATTRIBUTE_VALUE_FIELD = "value"; + + // model id + private SetOnce modelId = new SetOnce<>(); + // a map from attribute name like "host" to its value like "server_1" + // Use SortedMap so that the attributes are ordered and we can derive the unique + // string representation used in the hash ring. + private final SortedMap attributes; + + /** + * Create an entity that has multiple attributes + * @param attrs what we parsed from query output as a map of attribute and its values. + * @return the created entity + */ + public static Entity createEntityByReordering(Map attrs) { + SortedMap sortedMap = new TreeMap<>(); + for (Map.Entry categoryValuePair : attrs.entrySet()) { + sortedMap.put(categoryValuePair.getKey(), categoryValuePair.getValue().toString()); + } + return new Entity(sortedMap); + } + + /** + * Create an entity that has only one attribute + * @param attributeName the attribute's name + * @param attributeVal the attribute's value + * @return the created entity + */ + public static Entity createSingleAttributeEntity(String attributeName, String attributeVal) { + SortedMap sortedMap = new TreeMap<>(); + sortedMap.put(attributeName, attributeVal); + return new Entity(sortedMap); + } + + /** + * Create an entity from ordered attributes based on attribute names + * @param attrs attribute map + * @return the created entity + */ + public static Entity createEntityFromOrderedMap(SortedMap attrs) { + return new Entity(attrs); + } + + private Entity(SortedMap orderedAttrs) { + this.attributes = orderedAttrs; + } + + public Entity(StreamInput input) throws IOException { + this.attributes = new TreeMap<>(input.readMap(StreamInput::readString, StreamInput::readString)); + } + + /** + * Formatter when serializing to json. Used in cases when saving anomaly result for HCAD. 
+ * The order is Alphabetical sorting (the one used by JDK to compare Strings). + * Example: + * z0 + * z11 + * z2 + */ + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startArray(); + for (Map.Entry attr : attributes.entrySet()) { + builder.startObject().field(ATTRIBUTE_NAME_FIELD, attr.getKey()).field(ATTRIBUTE_VALUE_FIELD, attr.getValue()).endObject(); + } + builder.endArray(); + return builder; + } + + /** + * Return a map representing the entity, used in the stats API. + * + * A stats API broadcasts requests to all nodes and renders node responses using toXContent. + * + * For the local node, the stats API's calls toXContent on the node response directly. + * For remote node, the coordinating node gets a serialized content from + * ADStatsNodeResponse.writeTo, deserializes the content, and renders the result using toXContent. + * Since ADStatsNodeResponse.writeTo uses StreamOutput::writeGenericValue, we can only use + * a List<Map<String, String>> instead of the Entity object itself as + * StreamOutput::writeGenericValue only recognizes built-in types. + * + * This functions returns a map consistent with what toXContent returns. + * + * @return a map representing the entity + */ + public List> toStat() { + List> res = new ArrayList<>(attributes.size() * 2); + for (Map.Entry attr : attributes.entrySet()) { + Map elements = new TreeMap<>(); + elements.put(ATTRIBUTE_NAME_FIELD, attr.getKey()); + elements.put(ATTRIBUTE_VALUE_FIELD, attr.getValue()); + res.add(elements); + } + return res; + } + + public static Entity parse(XContentParser parser) throws IOException { + SortedMap entities = new TreeMap<>(); + String parsedValue = null; + String parsedName = null; + + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != Token.END_OBJECT) { + String fieldName = parser.currentName(); + // move to the field value + parser.nextToken(); + switch (fieldName) { + case ATTRIBUTE_NAME_FIELD: + parsedName = parser.text(); + break; + case ATTRIBUTE_VALUE_FIELD: + parsedValue = parser.text(); + break; + default: + break; + } + } + // reset every time I have seen a name-value pair. + if (parsedName != null && parsedValue != null) { + entities.put(parsedName, parsedValue); + parsedValue = null; + parsedName = null; + } + } + return new Entity(entities); + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + Entity that = (Entity) o; + return Objects.equal(attributes, that.attributes); + } + + @Generated + @Override + public int hashCode() { + return Objects.hashCode(attributes); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(attributes, StreamOutput::writeString, StreamOutput::writeString); + } + + /** + * Used to print Entity info and localizing a node in a hash ring. + * @return a normalized String representing the entity. + */ + @Override + public String toString() { + return normalizedAttributes(attributes); + } + + /** + * Return a string of the attributes in the ascending order of attribute names + * @return a normalized String corresponding to the Map. 
The string is + * deterministic (i.e., no matter in what order we insert values, + * the returned the string is the same). This is to ensure keys with the + * same content mapped to the same node in our hash ring. + * + */ + private static String normalizedAttributes(SortedMap attributes) { + return Joiner.on(",").withKeyValueSeparator("=").join(attributes); + } + + /** + * Create model Id out of config Id and the attribute name and value pairs + * + * HCAD v1 uses the categorical value as part of the model document Id, + * but OpenSearch's document Id can be at most 512 bytes. Categorical + * values are usually less than 256 characters but can grow to 32766 + * in theory. HCAD v1 skips an entity if the entity's name is more than + * 256 characters. We cannot do that in v2 as that can reject a lot of + * entities. To overcome the obstacle, we hash categorical values to a + * 128-bit string (like SHA-1 that git uses) and use the hash as part + * of the model document Id. + * + * We have choices regarding when to use the hash as part of a model + * document Id: for all HC detectors or an HC detector with multiple + * categorical fields. The challenge lies in providing backward + * compatibility by looking for a model checkpoint for an HC detector + * with one categorical field. If using hashes for all HC detectors, + * we need two get requests to ensure that a model checkpoint exists. + * One uses the document Id without a hash, while one uses the document + * Id with a hash. The dual get requests are ineffective. If limiting + * hashes to an HC detector with multiple categorical fields, there is + * no backward compatibility issue. However, the code will be branchy. + * One may wonder if backward compatibility can be ignored; indeed, + * the old checkpoints will be gone after a transition period during + * upgrading. During the transition period, the HC detector can + * experience unnecessary cold starts as if the detectors were just + * started. The checkpoint index size can double if every model has + * two model documents. The transition period can be three days since + * our checkpoint retention period is three days. + * + * There is no perfect solution. Considering that we can initialize one + * million models within 15 minutes in our performance test, we prefer + * to keep one and multiple categorical fields consistent and use hash + * only. This lifts the limitation that the categorical values cannot + * be more than 256 characters when there is one categorical field. + * Also, We will use hashes for new analyses like forecasting, regardless + * of the number of categorical fields. Using hashes always helps simplify + * our code base without worrying about whether the config is + * AnomalyDetector and when it is not. Thus, we prefer a hash-only solution + * for ease of use and maintainability. 
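+     *
+     * Illustrative example (hypothetical values, not a real hash): for config id "c1" and an
+     * entity {host=server_1}, the derived id looks like "c1_entity_" followed by a 22-character
+     * base64url string encoding the 128-bit hash.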
+ * + * @param configId config Id + * @param attributes Attributes of an entity + * @return the model Id + */ + private static Optional getModelId(String configId, SortedMap attributes) { + if (attributes.isEmpty()) { + return Optional.empty(); + } else { + String normalizedFields = normalizedAttributes(attributes); + MurmurHash3.Hash128 hashFunc = MurmurHash3 + .hash128( + normalizedFields.getBytes(StandardCharsets.UTF_8), + 0, + normalizedFields.length(), + RANDOM_SEED, + new MurmurHash3.Hash128() + ); + // 16 bytes = 128 bits + byte[] bytes = new byte[16]; + System.arraycopy(Numbers.longToBytes(hashFunc.h1), 0, bytes, 0, 8); + System.arraycopy(Numbers.longToBytes(hashFunc.h2), 0, bytes, 8, 8); + // Some bytes like 10 in ascii is corrupted in some systems. Base64 ensures we use safe bytes: https://tinyurl.com/mxmrhmhf + return Optional.of(configId + MODEL_ID_INFIX + Base64.getUrlEncoder().withoutPadding().encodeToString(bytes)); + } + } + + /** + * Get the cached model Id if present. Or recompute one if missing. + * + * @param configId Id. Used as part of model Id. + * @return Model Id. Can be missing (e.g., the field value is too long for single-category detector) + */ + public Optional getModelId(String configId) { + if (modelId.get() == null) { + // computing model id is not cheap and the result is deterministic. We only do it once. + Optional computedModelId = Entity.getModelId(configId, attributes); + if (computedModelId.isPresent()) { + this.modelId.set(computedModelId.get()); + } else { + this.modelId.set(null); + } + } + return Optional.ofNullable(modelId.get()); + } + + public Map getAttributes() { + return attributes; + } + + /** + * Generate multi-term query filter like + * GET /company/_search + { + "query": { + "bool": { + "filter": [ + { + "term": { + "ip": "1.2.3.4" + } + }, + { + "term": { + "name.keyword": "Kaituo" + } + } + ] + } + } + } + * + *@return a list of term query builder + */ + public List getTermQueryBuilders() { + List res = new ArrayList<>(); + for (Map.Entry attribute : attributes.entrySet()) { + res.add(new TermQueryBuilder(attribute.getKey(), attribute.getValue())); + } + return res; + } + + public List getTermQueryBuilders(String pathPrefix) { + List res = new ArrayList<>(); + for (Map.Entry attribute : attributes.entrySet()) { + res.add(new TermQueryBuilder(pathPrefix + attribute.getKey(), attribute.getValue())); + } + return res; + } + + /** + * From json to Entity instance + * @param entityValue json array consisting attributes + * @return Entity instance + * @throws IOException when there is an deserialization issue. 
+ */ + public static Entity fromJsonArray(Object entityValue) throws IOException { + XContentBuilder content = JsonXContent.contentBuilder(); + content.startObject(); + content.field(CommonName.ENTITY_KEY, entityValue); + content.endObject(); + + try ( + InputStream stream = BytesReference.bytes(content).streamInput(); + XContentParser parser = JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, stream) + ) { + // move to content.StartObject + parser.nextToken(); + // move to CommonName.ENTITY_KEY + parser.nextToken(); + // move to start of the array + parser.nextToken(); + return Entity.parse(parser); + } + } + + public static Optional fromJsonObject(XContentParser parser) throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + if (false == CommonName.ENTITY_KEY.equals(parser.currentName())) { + // not an object with "entity" as the root key + return Optional.empty(); + } + // move to start of the array + parser.nextToken(); + return Optional.of(Entity.parse(parser)); + } + return Optional.empty(); + } +} diff --git a/src/main/java/org/opensearch/timeseries/model/Feature.java b/src/main/java/org/opensearch/timeseries/model/Feature.java index e8f58dde6..045a6b96b 100644 --- a/src/main/java/org/opensearch/timeseries/model/Feature.java +++ b/src/main/java/org/opensearch/timeseries/model/Feature.java @@ -11,15 +11,15 @@ package org.opensearch.timeseries.model; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import org.apache.logging.log4j.util.Strings; import org.opensearch.common.UUIDs; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; diff --git a/src/main/java/org/opensearch/timeseries/model/Feature.java-e b/src/main/java/org/opensearch/timeseries/model/Feature.java-e new file mode 100644 index 000000000..045a6b96b --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/Feature.java-e @@ -0,0 +1,170 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.apache.logging.log4j.util.Strings; +import org.opensearch.common.UUIDs; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.base.Objects; + +/** + * time series to analyze (a.k.a. feature) + */ +public class Feature implements Writeable, ToXContentObject { + + private static final String FEATURE_ID_FIELD = "feature_id"; + private static final String FEATURE_NAME_FIELD = "feature_name"; + private static final String FEATURE_ENABLED_FIELD = "feature_enabled"; + private static final String AGGREGATION_QUERY = "aggregation_query"; + + private final String id; + private final String name; + private final Boolean enabled; + private final AggregationBuilder aggregation; + + /** + * Constructor function. + * @param id feature id + * @param name feature name + * @param enabled feature enabled or not + * @param aggregation feature aggregation query + */ + public Feature(String id, String name, Boolean enabled, AggregationBuilder aggregation) { + if (Strings.isBlank(name)) { + throw new IllegalArgumentException("Feature name should be set"); + } + if (aggregation == null) { + throw new IllegalArgumentException("Feature aggregation query should be set"); + } + this.id = id; + this.name = name; + this.enabled = enabled; + this.aggregation = aggregation; + } + + public Feature(StreamInput input) throws IOException { + this.id = input.readString(); + this.name = input.readString(); + this.enabled = input.readBoolean(); + this.aggregation = input.readNamedWriteable(AggregationBuilder.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(this.id); + out.writeString(this.name); + out.writeBoolean(this.enabled); + out.writeNamedWriteable(aggregation); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder + .startObject() + .field(FEATURE_ID_FIELD, id) + .field(FEATURE_NAME_FIELD, name) + .field(FEATURE_ENABLED_FIELD, enabled) + .field(AGGREGATION_QUERY) + .startObject() + .value(aggregation) + .endObject(); + return xContentBuilder.endObject(); + } + + /** + * Parse raw json content into feature instance. 
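+     * An illustrative document (hypothetical names and values) would look like:
+     * {"feature_id": "f1", "feature_name": "total_error", "feature_enabled": true,
+     *  "aggregation_query": {"total_error": {"sum": {"field": "error_count"}}}}
+     * If feature_id is absent, the randomly generated UUID above is kept as the id.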
+ * + * @param parser json based content parser + * @return feature instance + * @throws IOException IOException if content can't be parsed correctly + */ + public static Feature parse(XContentParser parser) throws IOException { + String id = UUIDs.base64UUID(); + String name = null; + Boolean enabled = null; + AggregationBuilder aggregation = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + + parser.nextToken(); + switch (fieldName) { + case FEATURE_ID_FIELD: + id = parser.text(); + break; + case FEATURE_NAME_FIELD: + name = parser.text(); + break; + case FEATURE_ENABLED_FIELD: + enabled = parser.booleanValue(); + break; + case AGGREGATION_QUERY: + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + aggregation = ParseUtils.toAggregationBuilder(parser); + break; + default: + break; + } + } + return new Feature(id, name, enabled, aggregation); + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + Feature feature = (Feature) o; + return Objects.equal(getId(), feature.getId()) + && Objects.equal(getName(), feature.getName()) + && Objects.equal(getEnabled(), feature.getEnabled()) + && Objects.equal(getAggregation(), feature.getAggregation()); + } + + @Generated + @Override + public int hashCode() { + return Objects.hashCode(id, name, enabled); + } + + public String getId() { + return id; + } + + public String getName() { + return name; + } + + public Boolean getEnabled() { + return enabled; + } + + public AggregationBuilder getAggregation() { + return aggregation; + } +} diff --git a/src/main/java/org/opensearch/timeseries/model/FeatureData.java b/src/main/java/org/opensearch/timeseries/model/FeatureData.java index 23171f831..584dfbf4f 100644 --- a/src/main/java/org/opensearch/timeseries/model/FeatureData.java +++ b/src/main/java/org/opensearch/timeseries/model/FeatureData.java @@ -11,12 +11,12 @@ package org.opensearch.timeseries.model; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.timeseries.annotation.Generated; diff --git a/src/main/java/org/opensearch/timeseries/model/FeatureData.java-e b/src/main/java/org/opensearch/timeseries/model/FeatureData.java-e new file mode 100644 index 000000000..584dfbf4f --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/FeatureData.java-e @@ -0,0 +1,110 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.annotation.Generated; + +import com.google.common.base.Objects; + +/** + * Feature data used by RCF model. + */ +public class FeatureData extends DataByFeatureId { + + public static final String FEATURE_NAME_FIELD = "feature_name"; + + private final String featureName; + + public FeatureData(String featureId, String featureName, Double data) { + super(featureId, data); + this.featureName = featureName; + } + + public FeatureData(StreamInput input) throws IOException { + this.featureId = input.readString(); + this.featureName = input.readString(); + this.data = input.readDouble(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder + .startObject() + .field(FEATURE_ID_FIELD, featureId) + .field(FEATURE_NAME_FIELD, featureName) + .field(DATA_FIELD, data); + return xContentBuilder.endObject(); + } + + public static FeatureData parse(XContentParser parser) throws IOException { + String featureId = null; + Double data = null; + String parsedFeatureName = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case FEATURE_ID_FIELD: + featureId = parser.text(); + break; + case FEATURE_NAME_FIELD: + parsedFeatureName = parser.text(); + break; + case DATA_FIELD: + data = parser.doubleValue(); + break; + default: + break; + } + } + return new FeatureData(featureId, parsedFeatureName, data); + } + + @Generated + @Override + public boolean equals(Object o) { + if (super.equals(o)) { + FeatureData that = (FeatureData) o; + return Objects.equal(featureName, that.featureName); + } + return false; + } + + @Generated + @Override + public int hashCode() { + return Objects.hashCode(super.hashCode(), featureName); + } + + @Generated + public String getFeatureName() { + return featureName; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(featureId); + out.writeString(featureName); + out.writeDouble(data); + } +} diff --git a/src/main/java/org/opensearch/timeseries/model/IndexableResult.java b/src/main/java/org/opensearch/timeseries/model/IndexableResult.java index d6186ba63..7ccc58b59 100644 --- a/src/main/java/org/opensearch/timeseries/model/IndexableResult.java +++ b/src/main/java/org/opensearch/timeseries/model/IndexableResult.java @@ -18,10 +18,10 @@ import java.util.Optional; import org.apache.commons.lang.builder.ToStringBuilder; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.timeseries.annotation.Generated; diff --git 
a/src/main/java/org/opensearch/timeseries/model/IndexableResult.java-e b/src/main/java/org/opensearch/timeseries/model/IndexableResult.java-e new file mode 100644 index 000000000..75fa95ac6 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/IndexableResult.java-e @@ -0,0 +1,258 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.model; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import org.apache.commons.lang.builder.ToStringBuilder; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.timeseries.annotation.Generated; + +import com.google.common.base.Objects; + +public abstract class IndexableResult implements Writeable, ToXContentObject { + + protected final String configId; + protected final List featureData; + protected final Instant dataStartTime; + protected final Instant dataEndTime; + protected final Instant executionStartTime; + protected final Instant executionEndTime; + protected final String error; + protected final Optional optionalEntity; + protected User user; + protected final Integer schemaVersion; + /* + * model id for easy aggregations of entities. The front end needs to query + * for entities ordered by the descending/ascending order of feature values. + * After supporting multi-category fields, it is hard to write such queries + * since the entity information is stored in a nested object array. + * Also, the front end has all code/queries/ helper functions in place to + * rely on a single key per entity combo. Adding model id to forecast result + * to help the transition to multi-categorical field less painful. 
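+     * For example (illustrative), the front end can sort or aggregate results on this single
+     * key instead of unwrapping the nested entity attribute array.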
+ */ + protected final String modelId; + protected final String entityId; + protected final String taskId; + + public IndexableResult( + String configId, + List featureData, + Instant dataStartTime, + Instant dataEndTime, + Instant executionStartTime, + Instant executionEndTime, + String error, + Optional entity, + User user, + Integer schemaVersion, + String modelId, + String taskId + ) { + this.configId = configId; + this.featureData = featureData; + this.dataStartTime = dataStartTime; + this.dataEndTime = dataEndTime; + this.executionStartTime = executionStartTime; + this.executionEndTime = executionEndTime; + this.error = error; + this.optionalEntity = entity; + this.user = user; + this.schemaVersion = schemaVersion; + this.modelId = modelId; + this.taskId = taskId; + this.entityId = getEntityId(entity, configId); + } + + public IndexableResult(StreamInput input) throws IOException { + this.configId = input.readString(); + int featureSize = input.readVInt(); + this.featureData = new ArrayList<>(featureSize); + for (int i = 0; i < featureSize; i++) { + featureData.add(new FeatureData(input)); + } + this.dataStartTime = input.readInstant(); + this.dataEndTime = input.readInstant(); + this.executionStartTime = input.readInstant(); + this.executionEndTime = input.readInstant(); + this.error = input.readOptionalString(); + if (input.readBoolean()) { + this.optionalEntity = Optional.of(new Entity(input)); + } else { + this.optionalEntity = Optional.empty(); + } + if (input.readBoolean()) { + this.user = new User(input); + } else { + user = null; + } + this.schemaVersion = input.readInt(); + this.modelId = input.readOptionalString(); + this.taskId = input.readOptionalString(); + this.entityId = input.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(configId); + out.writeVInt(featureData.size()); + for (FeatureData feature : featureData) { + feature.writeTo(out); + } + out.writeInstant(dataStartTime); + out.writeInstant(dataEndTime); + out.writeInstant(executionStartTime); + out.writeInstant(executionEndTime); + out.writeOptionalString(error); + if (optionalEntity.isPresent()) { + out.writeBoolean(true); + optionalEntity.get().writeTo(out); + } else { + out.writeBoolean(false); + } + if (user != null) { + out.writeBoolean(true); // user exists + user.writeTo(out); + } else { + out.writeBoolean(false); // user does not exist + } + out.writeInt(schemaVersion); + out.writeOptionalString(modelId); + out.writeOptionalString(taskId); + out.writeOptionalString(entityId); + } + + public String getConfigId() { + return configId; + } + + public List getFeatureData() { + return featureData; + } + + public Instant getDataStartTime() { + return dataStartTime; + } + + public Instant getDataEndTime() { + return dataEndTime; + } + + public Instant getExecutionStartTime() { + return executionStartTime; + } + + public Instant getExecutionEndTime() { + return executionEndTime; + } + + public String getError() { + return error; + } + + public Optional getEntity() { + return optionalEntity; + } + + public String getModelId() { + return modelId; + } + + public String getTaskId() { + return taskId; + } + + public String getEntityId() { + return entityId; + } + + /** + * entityId equals to model Id. It is hard to explain to users what + * modelId is. entityId is more user friendly. 
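+     * When an entity is present this simply delegates to Entity.getModelId(configId);
+     * otherwise it returns null.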
+ * @param entity Entity info + * @param configId config id + * @return entity id + */ + public static String getEntityId(Optional entity, String configId) { + return entity.flatMap(e -> e.getModelId(configId)).orElse(null); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + IndexableResult that = (IndexableResult) o; + return Objects.equal(configId, that.configId) + && Objects.equal(taskId, that.taskId) + && Objects.equal(featureData, that.featureData) + && Objects.equal(dataStartTime, that.dataStartTime) + && Objects.equal(dataEndTime, that.dataEndTime) + && Objects.equal(executionStartTime, that.executionStartTime) + && Objects.equal(executionEndTime, that.executionEndTime) + && Objects.equal(error, that.error) + && Objects.equal(optionalEntity, that.optionalEntity) + && Objects.equal(modelId, that.modelId) + && Objects.equal(entityId, that.entityId); + } + + @Generated + @Override + public int hashCode() { + return Objects + .hashCode( + configId, + taskId, + featureData, + dataStartTime, + dataEndTime, + executionStartTime, + executionEndTime, + error, + optionalEntity, + modelId, + entityId + ); + } + + @Override + public String toString() { + return new ToStringBuilder(this) + .append("configId", configId) + .append("taskId", taskId) + .append("featureData", featureData) + .append("dataStartTime", dataStartTime) + .append("dataEndTime", dataEndTime) + .append("executionStartTime", executionStartTime) + .append("executionEndTime", executionEndTime) + .append("error", error) + .append("entity", optionalEntity) + .append("modelId", modelId) + .append("entityId", entityId) + .toString(); + } + + /** + * Used to throw away requests when index pressure is high. + * @return whether the result is high priority. + */ + public abstract boolean isHighPriority(); +} diff --git a/src/main/java/org/opensearch/timeseries/model/IntervalTimeConfiguration.java b/src/main/java/org/opensearch/timeseries/model/IntervalTimeConfiguration.java index 78df95467..eaa6301df 100644 --- a/src/main/java/org/opensearch/timeseries/model/IntervalTimeConfiguration.java +++ b/src/main/java/org/opensearch/timeseries/model/IntervalTimeConfiguration.java @@ -18,8 +18,8 @@ import java.util.Set; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.annotation.Generated; diff --git a/src/main/java/org/opensearch/timeseries/model/IntervalTimeConfiguration.java-e b/src/main/java/org/opensearch/timeseries/model/IntervalTimeConfiguration.java-e new file mode 100644 index 000000000..eaa6301df --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/IntervalTimeConfiguration.java-e @@ -0,0 +1,122 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.model; + +import java.io.IOException; +import java.time.Duration; +import java.time.temporal.ChronoUnit; +import java.util.Locale; +import java.util.Set; + +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.annotation.Generated; + +import com.google.common.base.Objects; +import com.google.common.collect.ImmutableSet; + +public class IntervalTimeConfiguration extends TimeConfiguration { + + private long interval; + private ChronoUnit unit; + + private static final Set SUPPORTED_UNITS = ImmutableSet.of(ChronoUnit.MINUTES, ChronoUnit.SECONDS); + + /** + * Constructor function. + * + * @param interval interval period value + * @param unit time unit + */ + public IntervalTimeConfiguration(long interval, ChronoUnit unit) { + if (interval < 0) { + throw new IllegalArgumentException( + String + .format( + Locale.ROOT, + "Interval %s %s", + interval, + org.opensearch.timeseries.constant.CommonMessages.NEGATIVE_TIME_CONFIGURATION + ) + ); + } + if (!SUPPORTED_UNITS.contains(unit)) { + throw new IllegalArgumentException(String.format(Locale.ROOT, ADCommonMessages.INVALID_TIME_CONFIGURATION_UNITS, unit)); + } + this.interval = interval; + this.unit = unit; + } + + public IntervalTimeConfiguration(StreamInput input) throws IOException { + this.interval = input.readLong(); + this.unit = input.readEnum(ChronoUnit.class); + } + + public static IntervalTimeConfiguration readFrom(StreamInput input) throws IOException { + return new IntervalTimeConfiguration(input); + } + + public static long getIntervalInMinute(IntervalTimeConfiguration interval) { + if (interval.getUnit() == ChronoUnit.SECONDS) { + return interval.getInterval() / 60; + } + return interval.getInterval(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeLong(this.interval); + out.writeEnum(this.unit); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject().startObject(PERIOD_FIELD).field(INTERVAL_FIELD, interval).field(UNIT_FIELD, unit).endObject().endObject(); + return builder; + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + IntervalTimeConfiguration that = (IntervalTimeConfiguration) o; + return getInterval() == that.getInterval() && getUnit() == that.getUnit(); + } + + @Generated + @Override + public int hashCode() { + return Objects.hashCode(interval, unit); + } + + public long getInterval() { + return interval; + } + + public ChronoUnit getUnit() { + return unit; + } + + /** + * Returns the duration of the interval. 
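+     * For example, an interval of 10 MINUTES maps to Duration.ofMinutes(10).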
+ * + * @return the duration of the interval + */ + public Duration toDuration() { + return Duration.of(interval, unit); + } +} diff --git a/src/main/java/org/opensearch/timeseries/model/TimeConfiguration.java b/src/main/java/org/opensearch/timeseries/model/TimeConfiguration.java index 28e83333d..d370e524e 100644 --- a/src/main/java/org/opensearch/timeseries/model/TimeConfiguration.java +++ b/src/main/java/org/opensearch/timeseries/model/TimeConfiguration.java @@ -11,13 +11,13 @@ package org.opensearch.timeseries.model; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.time.temporal.ChronoUnit; import java.util.Locale; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentParser; diff --git a/src/main/java/org/opensearch/timeseries/model/TimeConfiguration.java-e b/src/main/java/org/opensearch/timeseries/model/TimeConfiguration.java-e new file mode 100644 index 000000000..d370e524e --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/TimeConfiguration.java-e @@ -0,0 +1,75 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.time.temporal.ChronoUnit; +import java.util.Locale; + +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentParser; + +/** + * TimeConfiguration represents the time configuration for a job which runs regularly. + */ +public abstract class TimeConfiguration implements Writeable, ToXContentObject { + + public static final String PERIOD_FIELD = "period"; + public static final String INTERVAL_FIELD = "interval"; + public static final String UNIT_FIELD = "unit"; + + /** + * Parse raw json content into schedule instance. 
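As a quick orientation on the IntervalTimeConfiguration class above, the following is a minimal sketch of how its helpers behave, assuming the plugin classes compile as shown in the diff; note that getIntervalInMinute uses integer division, so second-based intervals shorter than a minute round down.

    import java.time.Duration;
    import java.time.temporal.ChronoUnit;

    import org.opensearch.timeseries.model.IntervalTimeConfiguration;

    public class IntervalSketch {
        public static void main(String[] args) {
            IntervalTimeConfiguration tenMinutes = new IntervalTimeConfiguration(10, ChronoUnit.MINUTES);
            // Minute-based intervals are returned unchanged: prints 10.
            System.out.println(IntervalTimeConfiguration.getIntervalInMinute(tenMinutes));

            IntervalTimeConfiguration ninetySeconds = new IntervalTimeConfiguration(90, ChronoUnit.SECONDS);
            // Second-based intervals use integer division by 60: prints 1.
            System.out.println(IntervalTimeConfiguration.getIntervalInMinute(ninetySeconds));

            // toDuration delegates to java.time: prints PT1M30S.
            Duration duration = ninetySeconds.toDuration();
            System.out.println(duration);

            // Negative intervals or unsupported units (anything other than MINUTES or SECONDS)
            // are rejected by the constructor with an IllegalArgumentException.
        }
    }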
+ * + * @param parser json based content parser + * @return schedule instance + * @throws IOException IOException if content can't be parsed correctly + */ + public static TimeConfiguration parse(XContentParser parser) throws IOException { + long interval = 0; + ChronoUnit unit = null; + String scheduleType = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + scheduleType = parser.currentName(); + parser.nextToken(); + switch (scheduleType) { + case PERIOD_FIELD: + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String periodFieldName = parser.currentName(); + parser.nextToken(); + switch (periodFieldName) { + case INTERVAL_FIELD: + interval = parser.longValue(); + break; + case UNIT_FIELD: + unit = ChronoUnit.valueOf(parser.text().toUpperCase(Locale.ROOT)); + break; + default: + break; + } + } + break; + default: + break; + } + } + if (PERIOD_FIELD.equals(scheduleType)) { + return new IntervalTimeConfiguration(interval, unit); + } + throw new IllegalArgumentException("Find no schedule definition"); + } +} diff --git a/src/main/java/org/opensearch/timeseries/model/ValidationAspect.java-e b/src/main/java/org/opensearch/timeseries/model/ValidationAspect.java-e new file mode 100644 index 000000000..95fbf2217 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/ValidationAspect.java-e @@ -0,0 +1,69 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.model; + +import java.util.Collection; +import java.util.Set; + +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.timeseries.Name; +import org.opensearch.timeseries.constant.CommonName; + +/** + * Validation Aspect enum. There two types of validation types for validation API, + * these correlate to two the possible params passed to validate API. + *
+ * <ul>
+ * <li>DETECTOR: all of the validation checks that will be executed are based on the detector
+ * configuration settings. If any validation check fails, the AD creation process is blocked and
+ * the user is told which fields caused the failure.</li>
+ * </ul>
+ */ +public enum ValidationAspect implements Name { + DETECTOR(ADCommonName.DETECTOR_ASPECT), + MODEL(CommonName.MODEL_ASPECT), + FORECASTER(ForecastCommonName.FORECASTER_ASPECT); + + private String name; + + ValidationAspect(String name) { + this.name = name; + } + + /** + * Get validation aspect + * + * @return name + */ + @Override + public String getName() { + return name; + } + + public static ValidationAspect getName(String name) { + switch (name) { + case ADCommonName.DETECTOR_ASPECT: + return DETECTOR; + case CommonName.MODEL_ASPECT: + return MODEL; + case ForecastCommonName.FORECASTER_ASPECT: + return FORECASTER; + default: + throw new IllegalArgumentException("Unsupported validation aspects"); + } + } + + public static Set getNames(Collection names) { + return Name.getNameFromCollection(names, ValidationAspect::getName); + } +} diff --git a/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java-e b/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java-e new file mode 100644 index 000000000..01913a9c6 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/ValidationIssueType.java-e @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.model; + +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.timeseries.Name; + +public enum ValidationIssueType implements Name { + NAME(Config.NAME_FIELD), + TIMEFIELD_FIELD(Config.TIMEFIELD_FIELD), + SHINGLE_SIZE_FIELD(Config.SHINGLE_SIZE_FIELD), + INDICES(Config.INDICES_FIELD), + FEATURE_ATTRIBUTES(Config.FEATURE_ATTRIBUTES_FIELD), + CATEGORY(Config.CATEGORY_FIELD), + FILTER_QUERY(Config.FILTER_QUERY_FIELD), + WINDOW_DELAY(Config.WINDOW_DELAY_FIELD), + GENERAL_SETTINGS(Config.GENERAL_SETTINGS), + RESULT_INDEX(Config.RESULT_INDEX_FIELD), + TIMEOUT(Config.TIMEOUT), + AGGREGATION(Config.AGGREGATION), // this is a unique case where aggregation failed due to an issue in core but + // don't want to throw exception + IMPUTATION(Config.IMPUTATION_OPTION_FIELD), + DETECTION_INTERVAL(AnomalyDetector.DETECTION_INTERVAL_FIELD), + FORECAST_INTERVAL(Forecaster.FORECAST_INTERVAL_FIELD), + HORIZON_SIZE(Forecaster.HORIZON_FIELD); + + private String name; + + ValidationIssueType(String name) { + this.name = name; + } + + /** + * Get validation type + * + * @return name + */ + @Override + public String getName() { + return name; + } +} diff --git a/src/main/java/org/opensearch/timeseries/settings/DynamicNumericSetting.java-e b/src/main/java/org/opensearch/timeseries/settings/DynamicNumericSetting.java-e new file mode 100644 index 000000000..c9fe72a83 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/settings/DynamicNumericSetting.java-e @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.settings; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; 
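The ValidationAspect enum above is looked up by the string constants it wraps; a minimal sketch, assuming the ADCommonName and CommonName constants referenced in the diff are on the classpath.

    import java.util.Arrays;
    import java.util.Set;

    import org.opensearch.ad.constant.ADCommonName;
    import org.opensearch.timeseries.constant.CommonName;
    import org.opensearch.timeseries.model.ValidationAspect;

    public class ValidationAspectSketch {
        public static void main(String[] args) {
            // A validate request carries the aspect as a string; getName maps it back to the enum.
            ValidationAspect aspect = ValidationAspect.getName(ADCommonName.DETECTOR_ASPECT);
            System.out.println(aspect); // DETECTOR

            // getNames resolves a whole collection via Name.getNameFromCollection.
            Set<ValidationAspect> aspects = ValidationAspect
                .getNames(Arrays.asList(ADCommonName.DETECTOR_ASPECT, CommonName.MODEL_ASPECT));
            System.out.println(aspects); // [DETECTOR, MODEL], order not guaranteed

            // Any other string throws IllegalArgumentException("Unsupported validation aspects").
        }
    }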
+ +/** + * A container serving dynamic numeric setting. The caller does not have to call + * ClusterSettings.addSettingsUpdateConsumer and can access the most-up-to-date + * value using the singleton instance. This is convenient for a setting that's + * accessed by various places or it is not possible to install ClusterSettings.addSettingsUpdateConsumer + * as the enclosing instances are not singleton (i.e. deleted after use). + * + */ +public abstract class DynamicNumericSetting { + private static Logger logger = LogManager.getLogger(DynamicNumericSetting.class); + + private ClusterService clusterService; + /** Latest setting value for each registered key. Thread-safe is required. */ + private final Map latestSettings = new ConcurrentHashMap<>(); + + private final Map> settings; + + protected DynamicNumericSetting(Map> settings) { + this.settings = settings; + } + + private void setSettingsUpdateConsumers() { + for (Setting setting : settings.values()) { + clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, newVal -> { + logger.info("[AD] The value of setting [{}] changed to [{}]", setting.getKey(), newVal); + latestSettings.put(setting.getKey(), newVal); + }); + } + } + + public void init(ClusterService clusterService) { + this.clusterService = clusterService; + setSettingsUpdateConsumers(); + } + + /** + * Get setting value by key. Return default value if not configured explicitly. + * + * @param key setting key. + * @param Setting type + * @return T setting value or default + */ + @SuppressWarnings("unchecked") + public T getSettingValue(String key) { + return (T) latestSettings.getOrDefault(key, getSetting(key).getDefault(Settings.EMPTY)); + } + + /** + * Override existing value. + * @param key Key + * @param newVal New value + */ + public void setSettingValue(String key, Object newVal) { + latestSettings.put(key, newVal); + } + + private Setting getSetting(String key) { + if (settings.containsKey(key)) { + return settings.get(key); + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); + } + + public List> getSettings() { + return new ArrayList<>(settings.values()); + } +} diff --git a/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java-e b/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java-e new file mode 100644 index 000000000..a9aebff53 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java-e @@ -0,0 +1,191 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.settings; + +import java.time.Duration; + +import org.opensearch.common.settings.Setting; +import org.opensearch.common.unit.TimeValue; + +public class TimeSeriesSettings { + + // ====================================== + // Model parameters + // ====================================== + public static final int DEFAULT_SHINGLE_SIZE = 8; + + // max shingle size we have seen from external users + // the larger shingle size, the harder to fill in a complete shingle + public static final int MAX_SHINGLE_SIZE = 60; + + public static final String CONFIG_INDEX_MAPPING_FILE = "mappings/anomaly-detectors.json"; + + public static final String JOBS_INDEX_MAPPING_FILE = "mappings/anomaly-detector-jobs.json"; + + // 100,000 insertions costs roughly 1KB. 
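To show how the DynamicNumericSetting container above is meant to be used, here is a hypothetical subclass; the class name and setting key are invented for illustration, and clusterService is assumed to come from the plugin's createComponents wiring.

    import java.util.Map;

    import org.opensearch.cluster.service.ClusterService;
    import org.opensearch.common.settings.Setting;
    import org.opensearch.timeseries.settings.DynamicNumericSetting;

    // Hypothetical subclass used only for this sketch.
    public class ExampleNumericSetting extends DynamicNumericSetting {
        // Invented setting key; any dynamic node-scope numeric setting works the same way.
        public static final String EXAMPLE_LIMIT = "plugins.timeseries.example_limit";

        private static final ExampleNumericSetting INSTANCE = new ExampleNumericSetting();

        private ExampleNumericSetting() {
            super(Map.of(EXAMPLE_LIMIT, Setting.intSetting(EXAMPLE_LIMIT, 10, 1, Setting.Property.NodeScope, Setting.Property.Dynamic)));
        }

        public static ExampleNumericSetting getInstance() {
            return INSTANCE;
        }
    }

Calling ExampleNumericSetting.getInstance().init(clusterService) once at startup registers the settings update consumers; afterwards any component can call getSettingValue(ExampleNumericSetting.EXAMPLE_LIMIT) on the singleton and always see the latest cluster-level value, falling back to the declared default when the setting was never changed.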
+ public static final int DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION = 100_000; + + public static final double DOOR_KEEPER_FALSE_POSITIVE_RATE = 0.01; + + // clean up door keeper every 60 intervals + public static final int DOOR_KEEPER_MAINTENANCE_FREQ = 60; + + // 1 million insertion costs roughly 1 MB. + public static final int DOOR_KEEPER_FOR_CACHE_MAX_INSERTION = 1_000_000; + + // for a real-time operation, we trade off speed for memory as real time opearation + // only has to do one update/scoring per interval + public static final double REAL_TIME_BOUNDING_BOX_CACHE_RATIO = 0; + + // ====================================== + // Historical analysis + // ====================================== + public static final int MAX_BATCH_TASK_PIECE_SIZE = 10_000; + + // within an interval, how many percents are used to process requests. + // 1.0 means we use all of the detection interval to process requests. + // to ensure we don't block next interval, it is better to set it less than 1.0. + public static final float INTERVAL_RATIO_FOR_REQUESTS = 0.9f; + + public static final Duration HOURLY_MAINTENANCE = Duration.ofHours(1); + + // ====================================== + // Checkpoint setting + // ====================================== + // we won't accept a checkpoint larger than 30MB. Or we risk OOM. + // For reference, in RCF 1.0, the checkpoint of a RCF with 50 trees, 10 dimensions, + // 256 samples is of 3.2MB. + // In compact rcf, the same RCF is of 163KB. + // Since we allow at most 5 features, and the default shingle size is 8 and default + // tree number size is 100, we can have at most 25.6 MB in RCF 1.0. + // It is possible that cx increases the max features or shingle size, but we don't want + // to risk OOM for the flexibility. + public static final int MAX_CHECKPOINT_BYTES = 30_000_000; + + // Sets the cap on the number of buffer that can be allocated by the rcf deserialization + // buffer pool. Each buffer is of 512 bytes. Memory occupied by 20 buffers is 10.24 KB. + public static final int MAX_TOTAL_RCF_SERIALIZATION_BUFFERS = 20; + + // the size of the buffer used for rcf deserialization + public static final int SERIALIZATION_BUFFER_BYTES = 512; + + // ====================================== + // rate-limiting queue parameters + // ====================================== + /** + * CheckpointWriteRequest consists of IndexRequest (200 KB), and QueuedRequest + * fields (148 bytes, read comments of ENTITY_REQUEST_SIZE_CONSTANT). + * The total is roughly 200 KB per request. + * + * We don't want the total size exceeds 1% of the heap. + * We should have at most 1% heap / 200KB = heap / 20,000,000 + * For t3.small, 1% heap is of 10MB. The queue's size is up to + * 10^ 7 / 2.0 * 10^5 = 50 + */ + public static int CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES = 200_000; + + /** + * ResultWriteRequest consists of index request (roughly 1KB), and QueuedRequest + * fields (148 bytes, read comments of ENTITY_REQUEST_SIZE_CONSTANT). + * Plus Java object size (12 bytes), we have roughly 1160 bytes per request + * + * We don't want the total size exceeds 1% of the heap. + * We should have at most 1% heap / 1148 = heap / 116,000 + * For t3.small, 1% heap is of 10MB. 
The queue's size is up to + * 10^ 7 / 1160 = 8621 + */ + public static int RESULT_WRITE_QUEUE_SIZE_IN_BYTES = 1160; + + /** + * FeatureRequest has entityName (# category fields * 256, the recommended limit + * of a keyword field length), model Id (roughly 256 bytes), and QueuedRequest + * fields including config Id(roughly 128 bytes), dataStartTimeMillis (long, + * 8 bytes), and currentFeature (16 bytes, assume two features on average). + * Plus Java object size (12 bytes), we have roughly 932 bytes per request + * assuming we have 2 categorical fields (plan to support 2 categorical fields now). + * We don't want the total size exceeds 0.1% of the heap. + * We can have at most 0.1% heap / 932 = heap / 932,000. + * For t3.small, 0.1% heap is of 1MB. The queue's size is up to + * 10^ 6 / 932 = 1072 + */ + public static int FEATURE_REQUEST_SIZE_IN_BYTES = 932; + + /** + * CheckpointMaintainRequest has model Id (roughly 256 bytes), and QueuedRequest + * fields including detector Id(roughly 128 bytes), expirationEpochMs (long, + * 8 bytes), and priority (12 bytes). + * Plus Java object size (12 bytes), we have roughly 416 bytes per request. + * We don't want the total size exceeds 0.1% of the heap. + * We can have at most 0.1% heap / 416 = heap / 416,000. + * For t3.small, 0.1% heap is of 1MB. The queue's size is up to + * 10^ 6 / 416 = 2403 + */ + public static int CHECKPOINT_MAINTAIN_REQUEST_SIZE_IN_BYTES = 416; + + public static final float MAX_QUEUED_TASKS_RATIO = 0.5f; + + public static final float MEDIUM_SEGMENT_PRUNE_RATIO = 0.1f; + + public static final float LOW_SEGMENT_PRUNE_RATIO = 0.3f; + + // expensive maintenance (e.g., queue maintenance) with 1/10000 probability + public static final int MAINTENANCE_FREQ_CONSTANT = 10000; + + public static final Duration QUEUE_MAINTENANCE = Duration.ofMinutes(10); + + // ====================================== + // ML parameters + // ====================================== + // RCF + public static final int NUM_SAMPLES_PER_TREE = 256; + + public static final int NUM_TREES = 30; + + public static final double TIME_DECAY = 0.0001; + + // If we have 32 + shingleSize (hopefully recent) values, RCF can get up and running. It will be noisy — + // there is a reason that default size is 256 (+ shingle size), but it may be more useful for people to + /// start seeing some results. 
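The byte estimates above feed a simple capacity rule: a queue may hold at most a fixed fraction of the heap divided by the per-request size. A minimal sketch of that arithmetic, reproducing the numbers quoted in the comments (the 1 GB heap stands in for a t3.small node):

    import org.opensearch.timeseries.settings.TimeSeriesSettings;

    public class QueueSizingSketch {
        // capacity = (heap bytes * allowed heap fraction) / estimated bytes per queued request
        static long capacity(long heapBytes, double heapFraction, long bytesPerRequest) {
            return (long) (heapBytes * heapFraction) / bytesPerRequest;
        }

        public static void main(String[] args) {
            long heapBytes = 1_000_000_000L; // roughly the heap of a t3.small node

            // Checkpoint writes: 1% of heap at ~200 KB per request, prints 50.
            System.out.println(capacity(heapBytes, 0.01, TimeSeriesSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES));

            // Result writes: 1% of heap at ~1160 bytes per request, prints 8620.
            System.out.println(capacity(heapBytes, 0.01, TimeSeriesSettings.RESULT_WRITE_QUEUE_SIZE_IN_BYTES));

            // Feature requests: 0.1% of heap at ~932 bytes per request, prints 1072.
            System.out.println(capacity(heapBytes, 0.001, TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES));
        }
    }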
+ public static final int NUM_MIN_SAMPLES = 32; + + // for a batch operation, we want all of the bounding box in-place for speed + public static final double BATCH_BOUNDING_BOX_CACHE_RATIO = 1; + + // ====================================== + // Cold start setting + // ====================================== + public static int MAX_COLD_START_ROUNDS = 2; + + // Thresholding + public static final double THRESHOLD_MIN_PVALUE = 0.995; + + // ====================================== + // Cold start setting + // ====================================== + public static final Setting MAX_RETRY_FOR_UNRESPONSIVE_NODE = Setting + .intSetting("plugins.timeseries.max_retry_for_unresponsive_node", 5, 0, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting BACKOFF_MINUTES = Setting + .positiveTimeSetting( + "plugins.timeseries.backoff_minutes", + TimeValue.timeValueMinutes(15), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + public static final Setting COOLDOWN_MINUTES = Setting + .positiveTimeSetting( + "plugins.timeseries.cooldown_minutes", + TimeValue.timeValueMinutes(5), + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); + + // ====================================== + // AD Index setting + // ====================================== + public static int MAX_UPDATE_RETRY_TIMES = 10_000; +} diff --git a/src/main/java/org/opensearch/timeseries/stats/StatNames.java-e b/src/main/java/org/opensearch/timeseries/stats/StatNames.java-e new file mode 100644 index 000000000..a72e3f1b0 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/stats/StatNames.java-e @@ -0,0 +1,70 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.stats; + +import java.util.HashSet; +import java.util.Set; + +/** + * Enum containing names of all external stats which will be returned in + * AD stats REST API. 
+ */ +public enum StatNames { + AD_EXECUTE_REQUEST_COUNT("ad_execute_request_count"), + AD_EXECUTE_FAIL_COUNT("ad_execute_failure_count"), + AD_HC_EXECUTE_REQUEST_COUNT("ad_hc_execute_request_count"), + AD_HC_EXECUTE_FAIL_COUNT("ad_hc_execute_failure_count"), + DETECTOR_COUNT("detector_count"), + SINGLE_ENTITY_DETECTOR_COUNT("single_entity_detector_count"), + MULTI_ENTITY_DETECTOR_COUNT("multi_entity_detector_count"), + ANOMALY_DETECTORS_INDEX_STATUS("anomaly_detectors_index_status"), + ANOMALY_RESULTS_INDEX_STATUS("anomaly_results_index_status"), + MODELS_CHECKPOINT_INDEX_STATUS("models_checkpoint_index_status"), + ANOMALY_DETECTION_JOB_INDEX_STATUS("anomaly_detection_job_index_status"), + ANOMALY_DETECTION_STATE_STATUS("anomaly_detection_state_status"), + MODEL_INFORMATION("models"), + AD_EXECUTING_BATCH_TASK_COUNT("ad_executing_batch_task_count"), + AD_CANCELED_BATCH_TASK_COUNT("ad_canceled_batch_task_count"), + AD_TOTAL_BATCH_TASK_EXECUTION_COUNT("ad_total_batch_task_execution_count"), + AD_BATCH_TASK_FAILURE_COUNT("ad_batch_task_failure_count"), + MODEL_COUNT("model_count"), + MODEL_CORRUTPION_COUNT("model_corruption_count"); + + private String name; + + StatNames(String name) { + this.name = name; + } + + /** + * Get stat name + * + * @return name + */ + public String getName() { + return name; + } + + /** + * Get set of stat names + * + * @return set of stat names + */ + public static Set getNames() { + Set names = new HashSet<>(); + + for (StatNames statName : StatNames.values()) { + names.add(statName.getName()); + } + return names; + } +} diff --git a/src/main/java/org/opensearch/timeseries/util/DataUtil.java-e b/src/main/java/org/opensearch/timeseries/util/DataUtil.java-e new file mode 100644 index 000000000..4f417e4f7 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/util/DataUtil.java-e @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.util; + +import java.util.Arrays; + +public class DataUtil { + /** + * Removes leading rows in a 2D array that contain Double.NaN values. + * + * This method iterates over the rows of the provided 2D array. If a row is found + * where all elements are not Double.NaN, it removes this row and all rows before it + * from the array. The modified array, which may be smaller than the original, is then returned. + * + * Note: If all rows contain at least one Double.NaN, the method will return an empty array. + * + * @param arr The 2D array from which leading rows containing Double.NaN are to be removed. + * @return A possibly smaller 2D array with leading rows containing Double.NaN removed. 
+ */ + public static double[][] ltrim(double[][] arr) { + int numRows = arr.length; + if (numRows == 0) { + return new double[0][0]; + } + + int numCols = arr[0].length; + int startIndex = numRows; // Initialized to numRows + for (int i = 0; i < numRows; i++) { + boolean hasNaN = false; + for (int j = 0; j < numCols; j++) { + if (Double.isNaN(arr[i][j])) { + hasNaN = true; + break; + } + } + if (!hasNaN) { + startIndex = i; + break; // Stop the loop as soon as a row without NaN is found + } + } + + return Arrays.copyOfRange(arr, startIndex, arr.length); + } + +} diff --git a/src/main/java/org/opensearch/timeseries/util/DiscoveryNodeFilterer.java-e b/src/main/java/org/opensearch/timeseries/util/DiscoveryNodeFilterer.java-e new file mode 100644 index 000000000..ca3ba4eba --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/util/DiscoveryNodeFilterer.java-e @@ -0,0 +1,98 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.util; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Predicate; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; + +/** + * Util class to filter unwanted node types + * + */ +public class DiscoveryNodeFilterer { + private static final Logger LOG = LogManager.getLogger(DiscoveryNodeFilterer.class); + private final ClusterService clusterService; + private final HotDataNodePredicate eligibleNodeFilter; + + public DiscoveryNodeFilterer(ClusterService clusterService) { + this.clusterService = clusterService; + eligibleNodeFilter = new HotDataNodePredicate(); + } + + /** + * Find nodes that are elibile to be used by us. For example, Ultrawarm + * introduces warm nodes into the ES cluster. Currently, we distribute + * model partitions to all data nodes in the cluster randomly, which + * could cause a model performance downgrade issue once warm nodes + * are throttled due to resource limitations. The PR excludes warm nodes + * to place model partitions. 
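A quick illustration of DataUtil.ltrim above: only the leading rows that contain a NaN are dropped, and everything from the first fully numeric row onward is returned unchanged.

    import java.util.Arrays;

    import org.opensearch.timeseries.util.DataUtil;

    public class LtrimSketch {
        public static void main(String[] args) {
            double[][] shingle = {
                { Double.NaN, 1.0 },  // dropped: leading row with a NaN
                { 2.0, Double.NaN },  // dropped: still in the leading NaN block
                { 3.0, 4.0 },         // first clean row, kept from here on
                { Double.NaN, 5.0 }   // kept: only leading rows are trimmed
            };
            // Prints [[3.0, 4.0], [NaN, 5.0]]
            System.out.println(Arrays.deepToString(DataUtil.ltrim(shingle)));

            // If every row contains a NaN, the result is empty; prints 0.
            System.out.println(DataUtil.ltrim(new double[][] { { Double.NaN } }).length);
        }
    }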
+ * @return an array of eligible data nodes + */ + public DiscoveryNode[] getEligibleDataNodes() { + ClusterState state = this.clusterService.state(); + final List eligibleNodes = new ArrayList<>(); + for (DiscoveryNode node : state.nodes()) { + if (eligibleNodeFilter.test(node)) { + eligibleNodes.add(node); + } + } + return eligibleNodes.toArray(new DiscoveryNode[0]); + } + + public DiscoveryNode[] getAllNodes() { + ClusterState state = this.clusterService.state(); + final List nodes = new ArrayList<>(); + for (DiscoveryNode node : state.nodes()) { + nodes.add(node); + } + return nodes.toArray(new DiscoveryNode[0]); + } + + public boolean isEligibleDataNode(DiscoveryNode node) { + return eligibleNodeFilter.test(node); + } + + /** + * + * @return the number of eligible data nodes + */ + public int getNumberOfEligibleDataNodes() { + return getEligibleDataNodes().length; + } + + /** + * @param node a discovery node + * @return whether we should use this node for AD + */ + public boolean isEligibleNode(DiscoveryNode node) { + return eligibleNodeFilter.test(node); + } + + static class HotDataNodePredicate implements Predicate { + @Override + public boolean test(DiscoveryNode discoveryNode) { + return discoveryNode.isDataNode() + && discoveryNode + .getAttributes() + .getOrDefault(ADCommonName.BOX_TYPE_KEY, ADCommonName.HOT_BOX_TYPE) + .equals(ADCommonName.HOT_BOX_TYPE); + } + } +} diff --git a/src/main/java/org/opensearch/timeseries/util/ParseUtils.java b/src/main/java/org/opensearch/timeseries/util/ParseUtils.java index 0f8fcd724..ee73be777 100644 --- a/src/main/java/org/opensearch/timeseries/util/ParseUtils.java +++ b/src/main/java/org/opensearch/timeseries/util/ParseUtils.java @@ -14,7 +14,7 @@ import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_GET_USER_INFO; import static org.opensearch.ad.constant.ADCommonMessages.NO_PERMISSION_TO_ACCESS_DETECTOR; import static org.opensearch.ad.constant.ADCommonName.EPOCH_MILLIS_FORMAT; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.search.aggregations.AggregationBuilders.dateRange; import static org.opensearch.search.aggregations.AggregatorFactories.VALID_AGG_NAME; import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; @@ -45,11 +45,11 @@ import org.opensearch.ad.transport.GetAnomalyDetectorResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.ParsingException; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.ParsingException; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; diff --git a/src/main/java/org/opensearch/timeseries/util/ParseUtils.java-e b/src/main/java/org/opensearch/timeseries/util/ParseUtils.java-e new file mode 100644 index 000000000..649bb23cb --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/util/ParseUtils.java-e @@ -0,0 +1,758 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.util; + +import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_GET_USER_INFO; +import static org.opensearch.ad.constant.ADCommonMessages.NO_PERMISSION_TO_ACCESS_DETECTOR; +import static org.opensearch.ad.constant.ADCommonName.EPOCH_MILLIS_FORMAT; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.search.aggregations.AggregationBuilders.dateRange; +import static org.opensearch.search.aggregations.AggregatorFactories.VALID_AGG_NAME; +import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; +import static org.opensearch.timeseries.settings.TimeSeriesSettings.MAX_BATCH_TASK_PIECE_SIZE; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.function.Consumer; +import java.util.regex.Matcher; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.transport.GetAnomalyDetectorResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.ParsingException; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.BaseAggregationBuilder; +import org.opensearch.search.aggregations.PipelineAggregationBuilder; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; +import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.composite.DateHistogramValuesSourceBuilder; +import org.opensearch.search.aggregations.bucket.histogram.DateHistogramInterval; +import org.opensearch.search.aggregations.bucket.range.DateRangeAggregationBuilder; +import org.opensearch.search.aggregations.metrics.Max; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import 
org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; + +import com.carrotsearch.hppc.DoubleArrayList; +import com.google.common.collect.ImmutableList; + +/** + * Parsing utility functions. + */ +public final class ParseUtils { + private static final Logger logger = LogManager.getLogger(ParseUtils.class); + + private ParseUtils() {} + + /** + * Parse content parser to {@link java.time.Instant}. + * + * @param parser json based content parser + * @return instance of {@link java.time.Instant} + * @throws IOException IOException if content can't be parsed correctly + */ + public static Instant toInstant(XContentParser parser) throws IOException { + if (parser.currentToken() == null || parser.currentToken() == XContentParser.Token.VALUE_NULL) { + return null; + } + if (parser.currentToken().isValue()) { + return Instant.ofEpochMilli(parser.longValue()); + } + return null; + } + + /** + * Parse content parser to {@link AggregationBuilder}. + * + * @param parser json based content parser + * @return instance of {@link AggregationBuilder} + * @throws IOException IOException if content can't be parsed correctly + */ + public static AggregationBuilder toAggregationBuilder(XContentParser parser) throws IOException { + AggregatorFactories.Builder parsed = AggregatorFactories.parseAggregators(parser); + return parsed.getAggregatorFactories().iterator().next(); + } + + /** + * Parse json String into {@link XContentParser}. + * + * @param content json string + * @param contentRegistry ES named content registry + * @return instance of {@link XContentParser} + * @throws IOException IOException if content can't be parsed correctly + */ + public static XContentParser parser(String content, NamedXContentRegistry contentRegistry) throws IOException { + XContentParser parser = XContentType.JSON.xContent().createParser(contentRegistry, LoggingDeprecationHandler.INSTANCE, content); + parser.nextToken(); + return parser; + } + + /** + * parse aggregation String into {@link AggregatorFactories.Builder}. + * + * @param aggQuery aggregation query string + * @param xContentRegistry ES named content registry + * @param aggName aggregation name, if set, will use it to replace original aggregation name + * @return instance of {@link AggregatorFactories.Builder} + * @throws IOException IOException if content can't be parsed correctly + */ + public static AggregatorFactories.Builder parseAggregators(String aggQuery, NamedXContentRegistry xContentRegistry, String aggName) + throws IOException { + XContentParser parser = parser(aggQuery, xContentRegistry); + return parseAggregators(parser, aggName); + } + + /** + * Parse content parser to {@link AggregatorFactories.Builder}. + * + * @param parser json based content parser + * @param aggName aggregation name, if set, will use it to replace original aggregation name + * @return instance of {@link AggregatorFactories.Builder} + * @throws IOException IOException if content can't be parsed correctly + */ + public static AggregatorFactories.Builder parseAggregators(XContentParser parser, String aggName) throws IOException { + return parseAggregators(parser, 0, aggName); + } + + /** + * Parse content parser to {@link AggregatorFactories.Builder}. 
+ * + * @param parser json based content parser + * @param level aggregation level, the top level start from 0 + * @param aggName aggregation name, if set, will use it to replace original aggregation name + * @return instance of {@link AggregatorFactories.Builder} + * @throws IOException IOException if content can't be parsed correctly + */ + public static AggregatorFactories.Builder parseAggregators(XContentParser parser, int level, String aggName) throws IOException { + Matcher validAggMatcher = VALID_AGG_NAME.matcher(""); + AggregatorFactories.Builder factories = new AggregatorFactories.Builder(); + + XContentParser.Token token = null; + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token != XContentParser.Token.FIELD_NAME) { + throw new ParsingException( + parser.getTokenLocation(), + "Unexpected token " + token + " in [aggs]: aggregations definitions must start with the name of the aggregation." + ); + } + final String aggregationName = aggName == null ? parser.currentName() : aggName; + if (!validAggMatcher.reset(aggregationName).matches()) { + throw new ParsingException( + parser.getTokenLocation(), + "Invalid aggregation name [" + + aggregationName + + "]. Aggregation names must be alpha-numeric and can only contain '_' and '-'" + ); + } + + token = parser.nextToken(); + if (token != XContentParser.Token.START_OBJECT) { + throw new ParsingException( + parser.getTokenLocation(), + "Aggregation definition for [" + + aggregationName + + " starts with a [" + + token + + "], expected a [" + + XContentParser.Token.START_OBJECT + + "]." + ); + } + + BaseAggregationBuilder aggBuilder = null; + AggregatorFactories.Builder subFactories = null; + + Map metaData = null; + + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token != XContentParser.Token.FIELD_NAME) { + throw new ParsingException( + parser.getTokenLocation(), + "Expected [" + + XContentParser.Token.FIELD_NAME + + "] under a [" + + XContentParser.Token.START_OBJECT + + "], but got a [" + + token + + "] in [" + + aggregationName + + "]", + parser.getTokenLocation() + ); + } + final String fieldName = parser.currentName(); + + token = parser.nextToken(); + if (token == XContentParser.Token.START_OBJECT) { + switch (fieldName) { + case "meta": + metaData = parser.map(); + break; + case "aggregations": + case "aggs": + if (subFactories != null) { + throw new ParsingException( + parser.getTokenLocation(), + "Found two sub aggregation definitions under [" + aggregationName + "]" + ); + } + subFactories = parseAggregators(parser, level + 1, null); + break; + default: + if (aggBuilder != null) { + throw new ParsingException( + parser.getTokenLocation(), + "Found two aggregation type definitions in [" + + aggregationName + + "]: [" + + aggBuilder.getType() + + "] and [" + + fieldName + + "]" + ); + } + + aggBuilder = parser.namedObject(BaseAggregationBuilder.class, fieldName, aggregationName); + } + } else { + throw new ParsingException( + parser.getTokenLocation(), + "Expected [" + + XContentParser.Token.START_OBJECT + + "] under [" + + fieldName + + "], but got a [" + + token + + "] in [" + + aggregationName + + "]" + ); + } + } + + if (aggBuilder == null) { + throw new ParsingException( + parser.getTokenLocation(), + "Missing definition for aggregation [" + aggregationName + "]", + parser.getTokenLocation() + ); + } else { + if (metaData != null) { + aggBuilder.setMetadata(metaData); + } + + if (subFactories != null) { + aggBuilder.subAggregations(subFactories); + } + + if 
(aggBuilder instanceof AggregationBuilder) { + factories.addAggregator((AggregationBuilder) aggBuilder); + } else { + factories.addPipelineAggregator((PipelineAggregationBuilder) aggBuilder); + } + } + } + + return factories; + } + + public static SearchSourceBuilder generateInternalFeatureQuery( + AnomalyDetector detector, + long startTime, + long endTime, + NamedXContentRegistry xContentRegistry + ) throws IOException { + RangeQueryBuilder rangeQuery = new RangeQueryBuilder(detector.getTimeField()) + .from(startTime) + .to(endTime) + .format("epoch_millis") + .includeLower(true) + .includeUpper(false); + + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().must(rangeQuery).must(detector.getFilterQuery()); + + SearchSourceBuilder internalSearchSourceBuilder = new SearchSourceBuilder().query(internalFilterQuery); + if (detector.getFeatureAttributes() != null) { + for (Feature feature : detector.getFeatureAttributes()) { + AggregatorFactories.Builder internalAgg = parseAggregators( + feature.getAggregation().toString(), + xContentRegistry, + feature.getId() + ); + internalSearchSourceBuilder.aggregation(internalAgg.getAggregatorFactories().iterator().next()); + } + } + + return internalSearchSourceBuilder; + } + + public static SearchSourceBuilder generatePreviewQuery( + AnomalyDetector detector, + List> ranges, + NamedXContentRegistry xContentRegistry + ) throws IOException { + + DateRangeAggregationBuilder dateRangeBuilder = dateRange("date_range").field(detector.getTimeField()).format("epoch_millis"); + for (Entry range : ranges) { + dateRangeBuilder.addRange(range.getKey(), range.getValue()); + } + + if (detector.getFeatureAttributes() != null) { + for (Feature feature : detector.getFeatureAttributes()) { + AggregatorFactories.Builder internalAgg = parseAggregators( + feature.getAggregation().toString(), + xContentRegistry, + feature.getId() + ); + dateRangeBuilder.subAggregation(internalAgg.getAggregatorFactories().iterator().next()); + } + } + + return new SearchSourceBuilder().query(detector.getFilterQuery()).size(0).aggregation(dateRangeBuilder); + } + + public static SearchSourceBuilder generateEntityColdStartQuery( + AnomalyDetector detector, + List> ranges, + Entity entity, + NamedXContentRegistry xContentRegistry + ) throws IOException { + + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(detector.getFilterQuery()); + + for (TermQueryBuilder term : entity.getTermQueryBuilders()) { + internalFilterQuery.filter(term); + } + + DateRangeAggregationBuilder dateRangeBuilder = dateRange("date_range").field(detector.getTimeField()).format("epoch_millis"); + for (Entry range : ranges) { + dateRangeBuilder.addRange(range.getKey(), range.getValue()); + } + + if (detector.getFeatureAttributes() != null) { + for (Feature feature : detector.getFeatureAttributes()) { + AggregatorFactories.Builder internalAgg = parseAggregators( + feature.getAggregation().toString(), + xContentRegistry, + feature.getId() + ); + dateRangeBuilder.subAggregation(internalAgg.getAggregatorFactories().iterator().next()); + } + } + + return new SearchSourceBuilder().query(internalFilterQuery).size(0).aggregation(dateRangeBuilder); + } + + /** + * Map feature data to its Id and name + * @param currentFeature Feature data + * @param detector Detector Config object + * @return a list of feature data with Id and name + */ + public static List getFeatureData(double[] currentFeature, AnomalyDetector detector) { + List featureIds = detector.getEnabledFeatureIds(); + List featureNames = 
detector.getEnabledFeatureNames(); + int featureLen = featureIds.size(); + List featureData = new ArrayList<>(); + for (int i = 0; i < featureLen; i++) { + featureData.add(new FeatureData(featureIds.get(i), featureNames.get(i), currentFeature[i])); + } + return featureData; + } + + public static SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSourceBuilder searchSourceBuilder) { + if (user == null) { + return searchSourceBuilder; + } + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + String userFieldName = "user"; + String userBackendRoleFieldName = "user.backend_roles.keyword"; + List backendRoles = user.getBackendRoles() != null ? user.getBackendRoles() : ImmutableList.of(); + // For normal case, user should have backend roles. + TermsQueryBuilder userRolesFilterQuery = QueryBuilders.termsQuery(userBackendRoleFieldName, backendRoles); + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(userFieldName, userRolesFilterQuery, ScoreMode.None); + boolQueryBuilder.must(nestedQueryBuilder); + QueryBuilder query = searchSourceBuilder.query(); + if (query == null) { + searchSourceBuilder.query(boolQueryBuilder); + } else if (query instanceof BoolQueryBuilder) { + ((BoolQueryBuilder) query).filter(boolQueryBuilder); + } else { + throw new TimeSeriesException("Search API does not support queries other than BoolQuery"); + } + return searchSourceBuilder; + } + + /** + * Generates a user string formed by the username, backend roles, roles and requested tenants separated by '|' + * (e.g., john||own_index,testrole|__user__, no backend role so you see two verticle line after john.). + * This is the user string format used internally in the OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT and may be + * parsed using User.parse(string). + * @param client Client containing user info. A public API request will fill in the user info in the thread context. + * @return parsed user object + */ + public static User getUserContext(Client client) { + String userStr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + logger.debug("Filtering result by " + userStr); + return User.parse(userStr); + } + + public static void resolveUserAndExecute( + User requestedUser, + String detectorId, + boolean filterByEnabled, + ActionListener listener, + Consumer function, + Client client, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry + ) { + try { + if (requestedUser == null || detectorId == null) { + // requestedUser == null means security is disabled or user is superadmin. In this case we don't need to + // check if request user have access to the detector or not. 
+ function.accept(null); + } else { + getDetector(requestedUser, detectorId, listener, function, client, clusterService, xContentRegistry, filterByEnabled); + } + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * If filterByEnabled is true, get detector and check if the user has permissions to access the detector, + * then execute function; otherwise, get detector and execute function + * @param requestUser user from request + * @param detectorId detector id + * @param listener action listener + * @param function consumer function + * @param client client + * @param clusterService cluster service + * @param xContentRegistry XContent registry + * @param filterByBackendRole filter by backend role or not + */ + public static void getDetector( + User requestUser, + String detectorId, + ActionListener listener, + Consumer function, + Client client, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry, + boolean filterByBackendRole + ) { + if (clusterService.state().metadata().indices().containsKey(CommonName.CONFIG_INDEX)) { + GetRequest request = new GetRequest(CommonName.CONFIG_INDEX).id(detectorId); + client + .get( + request, + ActionListener + .wrap( + response -> onGetAdResponse( + response, + requestUser, + detectorId, + listener, + function, + xContentRegistry, + filterByBackendRole + ), + exception -> { + logger.error("Failed to get anomaly detector: " + detectorId, exception); + listener.onFailure(exception); + } + ) + ); + } else { + listener.onFailure(new IndexNotFoundException(CommonName.CONFIG_INDEX)); + } + } + + public static void onGetAdResponse( + GetResponse response, + User requestUser, + String detectorId, + ActionListener listener, + Consumer function, + NamedXContentRegistry xContentRegistry, + boolean filterByBackendRole + ) { + if (response.isExists()) { + try ( + XContentParser parser = RestHandlerUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + AnomalyDetector detector = AnomalyDetector.parse(parser); + User resourceUser = detector.getUser(); + + if (!filterByBackendRole || checkUserPermissions(requestUser, resourceUser, detectorId) || isAdmin(requestUser)) { + function.accept(detector); + } else { + logger.debug("User: " + requestUser.getName() + " does not have permissions to access detector: " + detectorId); + listener.onFailure(new TimeSeriesException(NO_PERMISSION_TO_ACCESS_DETECTOR + detectorId)); + } + } catch (Exception e) { + listener.onFailure(new TimeSeriesException(FAIL_TO_GET_USER_INFO + detectorId)); + } + } else { + listener.onFailure(new ResourceNotFoundException(detectorId, FAIL_TO_FIND_CONFIG_MSG + detectorId)); + } + } + + /** + * 'all_access' role users are treated as admins. 
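The security helpers above revolve around the pipe-separated user string; a small sketch reusing the example string from the getUserContext javadoc, assuming the standard behavior of org.opensearch.commons.authuser.User.parse, with ParseUtils.isAdmin (defined just after this point) checking for the all_access role.

    import org.opensearch.commons.authuser.User;
    import org.opensearch.timeseries.util.ParseUtils;

    public class UserFilterSketch {
        public static void main(String[] args) {
            // Same format as described above: name|backend roles|roles|requested tenant.
            User john = User.parse("john||own_index,testrole|__user__");
            System.out.println(john.getName());           // john
            System.out.println(john.getBackendRoles());   // [] since nothing sits between the first two '|'
            System.out.println(john.getRoles());          // [own_index, testrole]
            System.out.println(ParseUtils.isAdmin(john)); // false, no all_access role

            // A user whose roles include all_access is treated as an admin.
            User admin = User.parse("kirk|ops|all_access");
            System.out.println(ParseUtils.isAdmin(admin)); // true
        }
    }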
+ * @param user of the current role + * @return boolean if the role is admin + */ + public static boolean isAdmin(User user) { + if (user == null) { + return false; + } + return user.getRoles().contains("all_access"); + } + + private static boolean checkUserPermissions(User requestedUser, User resourceUser, String detectorId) throws Exception { + if (resourceUser.getBackendRoles() == null || requestedUser.getBackendRoles() == null) { + return false; + } + // Check if requested user has backend role required to access the resource + for (String backendRole : requestedUser.getBackendRoles()) { + if (resourceUser.getBackendRoles().contains(backendRole)) { + logger + .debug( + "User: " + + requestedUser.getName() + + " has backend role: " + + backendRole + + " permissions to access detector: " + + detectorId + ); + return true; + } + } + return false; + } + + public static boolean checkFilterByBackendRoles(User requestedUser, ActionListener listener) { + if (requestedUser == null) { + return false; + } + if (requestedUser.getBackendRoles().isEmpty()) { + listener + .onFailure( + new TimeSeriesException( + "Filter by backend roles is enabled and User " + requestedUser.getName() + " does not have backend roles configured" + ) + ); + return false; + } + return true; + } + + /** + * Parse max timestamp aggregation named CommonName.AGG_NAME_MAX + * @param searchResponse Search response + * @return max timestamp + */ + public static Optional getLatestDataTime(SearchResponse searchResponse) { + return Optional + .ofNullable(searchResponse) + .map(SearchResponse::getAggregations) + .map(aggs -> aggs.asMap()) + .map(map -> (Max) map.get(CommonName.AGG_NAME_MAX_TIME)) + .map(agg -> (long) agg.getValue()); + } + + /** + * Generate batch query request for feature aggregation on given date range. 
+ * + * @param detector anomaly detector + * @param entity entity + * @param startTime start time + * @param endTime end time + * @param xContentRegistry content registry + * @return search source builder + * @throws IOException throw IO exception if fail to parse feature aggregation + * @throws TimeSeriesException throw AD exception if no enabled feature + */ + public static SearchSourceBuilder batchFeatureQuery( + AnomalyDetector detector, + Entity entity, + long startTime, + long endTime, + NamedXContentRegistry xContentRegistry + ) throws IOException { + RangeQueryBuilder rangeQuery = new RangeQueryBuilder(detector.getTimeField()) + .from(startTime) + .to(endTime) + .format(EPOCH_MILLIS_FORMAT) + .includeLower(true) + .includeUpper(false); + + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().must(rangeQuery).must(detector.getFilterQuery()); + + if (detector.isHighCardinality() && entity != null && entity.getAttributes().size() > 0) { + entity + .getAttributes() + .entrySet() + .forEach(attr -> { internalFilterQuery.filter(new TermQueryBuilder(attr.getKey(), attr.getValue())); }); + } + + long intervalSeconds = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().getSeconds(); + + List> sources = new ArrayList<>(); + sources + .add( + new DateHistogramValuesSourceBuilder(CommonName.DATE_HISTOGRAM) + .field(detector.getTimeField()) + .fixedInterval(DateHistogramInterval.seconds((int) intervalSeconds)) + ); + + CompositeAggregationBuilder aggregationBuilder = new CompositeAggregationBuilder(CommonName.FEATURE_AGGS, sources) + .size(MAX_BATCH_TASK_PIECE_SIZE); + + if (detector.getEnabledFeatureIds().size() == 0) { + throw new TimeSeriesException("No enabled feature configured").countedInStats(false); + } + + for (Feature feature : detector.getFeatureAttributes()) { + if (feature.getEnabled()) { + AggregatorFactories.Builder internalAgg = parseAggregators( + feature.getAggregation().toString(), + xContentRegistry, + feature.getId() + ); + aggregationBuilder.subAggregation(internalAgg.getAggregatorFactories().iterator().next()); + } + } + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.aggregation(aggregationBuilder); + searchSourceBuilder.query(internalFilterQuery); + searchSourceBuilder.size(0); + + return searchSourceBuilder; + } + + public static boolean isNullOrEmpty(Collection collection) { + return collection == null || collection.size() == 0; + } + + public static boolean isNullOrEmpty(Map map) { + return map == null || map.size() == 0; + } + + /** + * Check if two lists of string equals or not without considering the order. + * If the list is null, will consider it equals to empty list. 
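The list comparison helper described above (its body follows in the diff) ignores ordering and duplicates and treats null as an empty list, as this small sketch shows.

    import java.util.Arrays;
    import java.util.Collections;

    import org.opensearch.timeseries.util.ParseUtils;

    public class ListEqualsSketch {
        public static void main(String[] args) {
            // Order does not matter because both sides are converted to sets; prints true.
            System.out.println(ParseUtils.listEqualsWithoutConsideringOrder(
                Arrays.asList("cpu", "memory"), Arrays.asList("memory", "cpu")));

            // Duplicates collapse for the same reason; prints true.
            System.out.println(ParseUtils.listEqualsWithoutConsideringOrder(
                Arrays.asList("cpu", "cpu"), Arrays.asList("cpu")));

            // null is treated as an empty list; prints true.
            System.out.println(ParseUtils.listEqualsWithoutConsideringOrder(
                null, Collections.emptyList()));
        }
    }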
+ * + * @param list1 first list + * @param list2 second list + * @return true if two list of string equals + */ + public static boolean listEqualsWithoutConsideringOrder(List list1, List list2) { + Set set1 = new HashSet<>(); + Set set2 = new HashSet<>(); + if (list1 != null) { + set1.addAll(list1); + } + if (list2 != null) { + set2.addAll(list2); + } + return Objects.equals(set1, set2); + } + + public static double[] parseDoubleArray(XContentParser parser) throws IOException { + DoubleArrayList oldValList = new DoubleArrayList(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + oldValList.add(parser.doubleValue()); + } + return oldValList.toArray(); + } + + public static List parseAggregationRequest(XContentParser parser) throws IOException { + List fieldNames = new ArrayList<>(); + XContentParser.Token token; + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + final String field = parser.currentName(); + switch (field) { + case "field": + parser.nextToken(); + fieldNames.add(parser.textOrNull()); + break; + default: + parser.skipChildren(); + break; + } + } + } + return fieldNames; + } + + public static List getFeatureFieldNames(AnomalyDetector detector, NamedXContentRegistry xContentRegistry) throws IOException { + List featureFields = new ArrayList<>(); + for (Feature feature : detector.getFeatureAttributes()) { + featureFields.add(getFieldNamesForFeature(feature, xContentRegistry).get(0)); + } + return featureFields; + } + + public static List getFieldNamesForFeature(Feature feature, NamedXContentRegistry xContentRegistry) throws IOException { + ParseUtils.parseAggregators(feature.getAggregation().toString(), xContentRegistry, feature.getId()); + XContentParser parser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, feature.getAggregation().toString()); + parser.nextToken(); + return ParseUtils.parseAggregationRequest(parser); + } + +} diff --git a/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java b/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java index 563eda9a8..73ef78aef 100644 --- a/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java +++ b/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java @@ -11,8 +11,8 @@ package org.opensearch.timeseries.util; -import static org.opensearch.rest.RestStatus.BAD_REQUEST; -import static org.opensearch.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import java.io.IOException; import java.util.HashSet; @@ -28,10 +28,11 @@ import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.common.Nullable; import org.opensearch.common.Strings; -import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentParser; @@ -39,7 +40,6 @@ import org.opensearch.indices.InvalidIndexNameException; import 
org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; -import org.opensearch.rest.RestStatus; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; diff --git a/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java-e b/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java-e new file mode 100644 index 000000000..9a1559428 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java-e @@ -0,0 +1,250 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.util; + +import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; + +import java.io.IOException; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.apache.commons.lang.ArrayUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchPhaseExecutionException; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.common.Nullable; +import org.opensearch.common.Strings; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.indices.InvalidIndexNameException; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Feature; + +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; + +/** + * Utility functions for REST handlers. 
+ */ +public final class RestHandlerUtils { + private static final Logger logger = LogManager.getLogger(RestHandlerUtils.class); + public static final String _ID = "_id"; + public static final String _VERSION = "_version"; + public static final String _SEQ_NO = "_seq_no"; + public static final String IF_SEQ_NO = "if_seq_no"; + public static final String _PRIMARY_TERM = "_primary_term"; + public static final String IF_PRIMARY_TERM = "if_primary_term"; + public static final String REFRESH = "refresh"; + public static final String DETECTOR_ID = "detectorID"; + public static final String RESULT_INDEX = "resultIndex"; + public static final String ANOMALY_DETECTOR = "anomaly_detector"; + public static final String ANOMALY_DETECTOR_JOB = "anomaly_detector_job"; + public static final String REALTIME_TASK = "realtime_detection_task"; + public static final String HISTORICAL_ANALYSIS_TASK = "historical_analysis_task"; + public static final String RUN = "_run"; + public static final String PREVIEW = "_preview"; + public static final String START_JOB = "_start"; + public static final String STOP_JOB = "_stop"; + public static final String PROFILE = "_profile"; + public static final String TYPE = "type"; + public static final String ENTITY = "entity"; + public static final String COUNT = "count"; + public static final String MATCH = "match"; + public static final String RESULTS = "results"; + public static final String TOP_ANOMALIES = "_topAnomalies"; + public static final String VALIDATE = "_validate"; + public static final ToXContent.MapParams XCONTENT_WITH_TYPE = new ToXContent.MapParams(ImmutableMap.of("with_type", "true")); + + public static final String OPENSEARCH_DASHBOARDS_USER_AGENT = "OpenSearch Dashboards"; + public static final String[] UI_METADATA_EXCLUDE = new String[] { Config.UI_METADATA_FIELD }; + + public static final String FORECASTER_ID = "forecasterID"; + public static final String FORECASTER = "forecaster"; + public static final String REST_STATUS = "rest_status"; + + private RestHandlerUtils() {} + + /** + * Checks to see if the request came from OpenSearch-Dashboards, if so we want to return the UI Metadata from the document. + * If the request came from the client then we exclude the UI Metadata from the search result. + * We also take into account the given `_source` field and respect the correct fields to be returned. + * @param request rest request + * @param searchSourceBuilder an instance of the searchSourceBuilder to fetch _source field + * @return instance of {@link org.opensearch.search.fetch.subphase.FetchSourceContext} + */ + public static FetchSourceContext getSourceContext(RestRequest request, SearchSourceBuilder searchSourceBuilder) { + String userAgent = coalesceToEmpty(request.header("User-Agent")); + + // If there is a _source given in request than we either add UI_Metadata to exclude or not depending on if request + // is from OpenSearch-Dashboards, if no _source field then we either exclude UI_metadata or return nothing at all. 
+ if (searchSourceBuilder.fetchSource() != null) { + if (userAgent.contains(OPENSEARCH_DASHBOARDS_USER_AGENT)) { + return new FetchSourceContext( + true, + searchSourceBuilder.fetchSource().includes(), + searchSourceBuilder.fetchSource().excludes() + ); + } else { + String[] newArray = (String[]) ArrayUtils.addAll(searchSourceBuilder.fetchSource().excludes(), UI_METADATA_EXCLUDE); + return new FetchSourceContext(true, searchSourceBuilder.fetchSource().includes(), newArray); + } + } else if (!userAgent.contains(OPENSEARCH_DASHBOARDS_USER_AGENT)) { + return new FetchSourceContext(true, Strings.EMPTY_ARRAY, UI_METADATA_EXCLUDE); + } else { + return null; + } + } + + public static XContentParser createXContentParser(RestChannel channel, BytesReference bytesReference) throws IOException { + return XContentHelper + .createParser(channel.request().getXContentRegistry(), LoggingDeprecationHandler.INSTANCE, bytesReference, XContentType.JSON); + } + + public static XContentParser createXContentParserFromRegistry(NamedXContentRegistry xContentRegistry, BytesReference bytesReference) + throws IOException { + return XContentHelper.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, bytesReference, XContentType.JSON); + } + + /** + * Check if there is configuration/syntax error in feature definition of config + * @param config config to check + * @param maxFeatures max allowed feature number + * @return error message if error exists; otherwise, null is returned + */ + public static String checkFeaturesSyntax(Config config, int maxFeatures) { + List features = config.getFeatureAttributes(); + if (features != null) { + if (features.size() > maxFeatures) { + return "Can't create more than " + maxFeatures + " features"; + } + return validateFeaturesConfig(config.getFeatureAttributes()); + } + return null; + } + + private static String validateFeaturesConfig(List features) { + final Set duplicateFeatureNames = new HashSet<>(); + final Set featureNames = new HashSet<>(); + final Set duplicateFeatureAggNames = new HashSet<>(); + final Set featureAggNames = new HashSet<>(); + + features.forEach(feature -> { + if (!featureNames.add(feature.getName())) { + duplicateFeatureNames.add(feature.getName()); + } + if (!featureAggNames.add(feature.getAggregation().getName())) { + duplicateFeatureAggNames.add(feature.getAggregation().getName()); + } + }); + + StringBuilder errorMsgBuilder = new StringBuilder(); + if (duplicateFeatureNames.size() > 0) { + errorMsgBuilder.append("There are duplicate feature names: "); + errorMsgBuilder.append(String.join(", ", duplicateFeatureNames)); + } + if (errorMsgBuilder.length() != 0 && duplicateFeatureAggNames.size() > 0) { + errorMsgBuilder.append(". "); + } + if (duplicateFeatureAggNames.size() > 0) { + errorMsgBuilder.append(CommonMessages.DUPLICATE_FEATURE_AGGREGATION_NAMES); + errorMsgBuilder.append(String.join(", ", duplicateFeatureAggNames)); + } + return errorMsgBuilder.toString(); + } + + public static boolean isExceptionCausedByInvalidQuery(Exception ex) { + if (!(ex instanceof SearchPhaseExecutionException)) { + return false; + } + SearchPhaseExecutionException exception = (SearchPhaseExecutionException) ex; + // If any shards return bad request and failure cause is IllegalArgumentException, we + // consider the feature query is invalid and will not count the error in failure stats. 
+ for (ShardSearchFailure failure : exception.shardFailures()) { + if (RestStatus.BAD_REQUEST != failure.status() || !(failure.getCause() instanceof IllegalArgumentException)) { + return false; + } + } + return true; + } + + /** + * Wrap action listener to avoid return verbose error message and wrong 500 error to user. + * Suggestion for exception handling in timeseries analysis (e.g., AD and Forecast): + * 1. If the error is caused by wrong input, throw IllegalArgumentException exception. + * 2. For other errors, please use TimeSeriesException or its subclass, or use + * OpenSearchStatusException. + * + * TODO: tune this function for wrapped exception, return root exception error message + * + * @param actionListener action listener + * @param generalErrorMessage general error message + * @param action listener response type + * @return wrapped action listener + */ + public static ActionListener wrapRestActionListener(ActionListener actionListener, String generalErrorMessage) { + return ActionListener.wrap(r -> { actionListener.onResponse(r); }, e -> { + logger.error("Wrap exception before sending back to user", e); + Throwable cause = Throwables.getRootCause(e); + if (isProperExceptionToReturn(e)) { + actionListener.onFailure(e); + } else if (isProperExceptionToReturn(cause)) { + actionListener.onFailure((Exception) cause); + } else { + RestStatus status = isBadRequest(e) ? BAD_REQUEST : INTERNAL_SERVER_ERROR; + String errorMessage = generalErrorMessage; + if (isBadRequest(e) || e instanceof TimeSeriesException) { + errorMessage = e.getMessage(); + } else if (cause != null && (isBadRequest(cause) || cause instanceof TimeSeriesException)) { + errorMessage = cause.getMessage(); + } + actionListener.onFailure(new OpenSearchStatusException(errorMessage, status)); + } + }); + } + + public static boolean isBadRequest(Throwable e) { + if (e == null) { + return false; + } + return e instanceof IllegalArgumentException || e instanceof ResourceNotFoundException; + } + + public static boolean isProperExceptionToReturn(Throwable e) { + if (e == null) { + return false; + } + return e instanceof OpenSearchStatusException || e instanceof IndexNotFoundException || e instanceof InvalidIndexNameException; + } + + private static String coalesceToEmpty(@Nullable String s) { + return s == null ? "" : s; + } +} diff --git a/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension b/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension index 627699843..01c2dfbe9 100644 --- a/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension +++ b/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension @@ -1 +1 @@ -org.opensearch.ad.AnomalyDetectorPlugin +org.opensearch.timeseries.TimeSeriesAnalyticsPlugin diff --git a/src/main/resources/es-plugin.properties b/src/main/resources/es-plugin.properties index a8dc4e91e..061a6f07d 100644 --- a/src/main/resources/es-plugin.properties +++ b/src/main/resources/es-plugin.properties @@ -9,5 +9,5 @@ # GitHub history for details. 
# -plugin=org.opensearch.ad.AnomalyDetectorPlugin +plugin=org.opensearch.timeseries.TimeSeriesAnalyticsPlugin version=${project.version} \ No newline at end of file diff --git a/src/test/java/org/opensearch/StreamInputOutputTests.java b/src/test/java/org/opensearch/StreamInputOutputTests.java index 1269c48f0..a1906c43f 100644 --- a/src/test/java/org/opensearch/StreamInputOutputTests.java +++ b/src/test/java/org/opensearch/StreamInputOutputTests.java @@ -39,8 +39,8 @@ import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.transport.TransportAddress; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.model.Entity; diff --git a/src/test/java/org/opensearch/StreamInputOutputTests.java-e b/src/test/java/org/opensearch/StreamInputOutputTests.java-e new file mode 100644 index 000000000..fa39da692 --- /dev/null +++ b/src/test/java/org/opensearch/StreamInputOutputTests.java-e @@ -0,0 +1,293 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static org.hamcrest.Matchers.equalTo; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.opensearch.action.FailedNodeException; +import org.opensearch.ad.model.EntityProfileName; +import org.opensearch.ad.model.ModelProfile; +import org.opensearch.ad.model.ModelProfileOnNode; +import org.opensearch.ad.transport.EntityProfileAction; +import org.opensearch.ad.transport.EntityProfileRequest; +import org.opensearch.ad.transport.EntityProfileResponse; +import org.opensearch.ad.transport.EntityResultRequest; +import org.opensearch.ad.transport.ProfileNodeResponse; +import org.opensearch.ad.transport.ProfileResponse; +import org.opensearch.ad.transport.RCFResultResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.model.Entity; + +/** + * Put in core package so that we can using Version's package private constructor + * + */ +public class StreamInputOutputTests extends AbstractTimeSeriesTest { + // public static Version V_1_1_0 = new Version(1010099, org.apache.lucene.util.Version.LUCENE_8_8_2); + private EntityResultRequest entityResultRequest; + private String detectorId; + private long start, end; + private Map entities; + private BytesStreamOutput output; + private String categoryField, categoryValue, categoryValue2; + private double[] feature; + private EntityProfileRequest entityProfileRequest; + private Entity entity, entity2; + private Set profilesToCollect; + private String nodeId = "abc"; + private 
String modelId = "123"; + private long modelSize = 712480L; + private long modelSize2 = 112480L; + private EntityProfileResponse entityProfileResponse; + private ProfileResponse profileResponse; + private RCFResultResponse rcfResultResponse; + + private boolean areEqualWithArrayValue(Map first, Map second) { + if (first.size() != second.size()) { + return false; + } + + return first.entrySet().stream().allMatch(e -> Arrays.equals(e.getValue(), second.get(e.getKey()))); + } + + @Override + public void setUp() throws Exception { + super.setUp(); + + categoryField = "a"; + categoryValue = "b"; + categoryValue2 = "b2"; + + feature = new double[] { 0.3 }; + detectorId = "123"; + + entity = Entity.createSingleAttributeEntity(categoryField, categoryValue); + entity2 = Entity.createSingleAttributeEntity(categoryField, categoryValue2); + + output = new BytesStreamOutput(); + } + + private void setUpEntityResultRequest() { + entities = new HashMap<>(); + entities.put(entity, feature); + start = 10L; + end = 20L; + entityResultRequest = new EntityResultRequest(detectorId, entities, start, end); + } + + /** + * @throws IOException when serialization/deserialization has issues. + */ + public void testDeSerializeEntityResultRequest() throws IOException { + setUpEntityResultRequest(); + + entityResultRequest.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + EntityResultRequest readRequest = new EntityResultRequest(streamInput); + assertThat(readRequest.getId(), equalTo(detectorId)); + assertThat(readRequest.getStart(), equalTo(start)); + assertThat(readRequest.getEnd(), equalTo(end)); + assertTrue(areEqualWithArrayValue(readRequest.getEntities(), entities)); + } + + private void setUpEntityProfileRequest() { + profilesToCollect = new HashSet(); + profilesToCollect.add(EntityProfileName.STATE); + entityProfileRequest = new EntityProfileRequest(detectorId, entity, profilesToCollect); + } + + /** + * @throws IOException when serialization/deserialization has issues. + */ + public void testDeserializeEntityProfileRequest() throws IOException { + setUpEntityProfileRequest(); + + entityProfileRequest.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + EntityProfileRequest readRequest = new EntityProfileRequest(streamInput); + assertThat(readRequest.getAdID(), equalTo(detectorId)); + assertThat(readRequest.getEntityValue(), equalTo(entity)); + assertThat(readRequest.getProfilesToCollect(), equalTo(profilesToCollect)); + } + + private void setUpEntityProfileResponse() { + long lastActiveTimestamp = 10L; + EntityProfileResponse.Builder builder = new EntityProfileResponse.Builder(); + builder.setLastActiveMs(lastActiveTimestamp).build(); + ModelProfile modelProfile = new ModelProfile(modelId, entity, modelSize); + ModelProfileOnNode model = new ModelProfileOnNode(nodeId, modelProfile); + builder.setModelProfile(model); + entityProfileResponse = builder.build(); + } + + /** + * @throws IOException when serialization/deserialization has issues. 
+ */ + public void testDeserializeEntityProfileResponse() throws IOException { + setUpEntityProfileResponse(); + + entityProfileResponse.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + EntityProfileResponse readResponse = EntityProfileAction.INSTANCE.getResponseReader().read(streamInput); + assertThat(readResponse.getModelProfile(), equalTo(entityProfileResponse.getModelProfile())); + assertThat(readResponse.getLastActiveMs(), equalTo(entityProfileResponse.getLastActiveMs())); + assertThat(readResponse.getTotalUpdates(), equalTo(entityProfileResponse.getTotalUpdates())); + } + + @SuppressWarnings("serial") + private void setUpProfileResponse() { + String node1 = "node1"; + String nodeName1 = "nodename1"; + DiscoveryNode discoveryNode1_1 = new DiscoveryNode( + nodeName1, + node1, + new TransportAddress(TransportAddress.META_ADDRESS, 9300), + emptyMap(), + emptySet(), + Version.V_2_1_0 + ); + + String node2 = "node2"; + String nodeName2 = "nodename2"; + DiscoveryNode discoveryNode2 = new DiscoveryNode( + nodeName2, + node2, + new TransportAddress(TransportAddress.META_ADDRESS, 9301), + emptyMap(), + emptySet(), + Version.V_2_1_0 + ); + + String model1Id = "model1"; + String model2Id = "model2"; + + Map modelSizeMap1 = new HashMap() { + { + put(model1Id, modelSize); + put(model2Id, modelSize2); + } + }; + Map modelSizeMap2 = new HashMap(); + + int shingleSize = 8; + + ModelProfile modelProfile = new ModelProfile(model1Id, entity, modelSize); + ModelProfile modelProfile2 = new ModelProfile(model2Id, entity2, modelSize2); + + ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse( + discoveryNode1_1, + modelSizeMap1, + shingleSize, + 0, + 0, + Arrays.asList(modelProfile, modelProfile2), + modelSizeMap1.size() + ); + ProfileNodeResponse profileNodeResponse2 = new ProfileNodeResponse( + discoveryNode2, + modelSizeMap2, + -1, + 0, + 0, + new ArrayList<>(), + modelSizeMap2.size() + ); + ProfileNodeResponse profileNodeResponse3 = new ProfileNodeResponse( + discoveryNode2, + null, + -1, + 0, + 0, + // null model size. Test if we can handle this case + null, + modelSizeMap2.size() + ); + List profileNodeResponses = Arrays.asList(profileNodeResponse1, profileNodeResponse2, profileNodeResponse3); + List failures = Collections.emptyList(); + + ClusterName clusterName = new ClusterName("test-cluster-name"); + profileResponse = new ProfileResponse(clusterName, profileNodeResponses, failures); + } + + /** + * @throws IOException when serialization/deserialization has issues. 
+ */ + public void testDeserializeProfileResponse() throws IOException { + setUpProfileResponse(); + + profileResponse.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + ProfileResponse readResponse = new ProfileResponse(streamInput); + assertThat(readResponse.getModelProfile(), equalTo(profileResponse.getModelProfile())); + assertThat(readResponse.getShingleSize(), equalTo(profileResponse.getShingleSize())); + assertThat(readResponse.getActiveEntities(), equalTo(profileResponse.getActiveEntities())); + assertThat(readResponse.getTotalUpdates(), equalTo(profileResponse.getTotalUpdates())); + assertThat(readResponse.getCoordinatingNode(), equalTo(profileResponse.getCoordinatingNode())); + assertThat(readResponse.getTotalSizeInBytes(), equalTo(profileResponse.getTotalSizeInBytes())); + assertThat(readResponse.getModelCount(), equalTo(profileResponse.getModelCount())); + } + + private void setUpRCFResultResponse() { + rcfResultResponse = new RCFResultResponse( + 0.345, + 0.123, + 30, + new double[] { 0.3, 0.7 }, + 134, + 0.4, + Version.CURRENT, + randomIntBetween(-3, 0), + new double[] { randomDoubleBetween(0, 1.0, true), randomDoubleBetween(0, 1.0, true) }, + new double[][] { new double[] { randomDouble(), randomDouble() } }, + new double[] { randomDoubleBetween(0, 1.0, true), randomDoubleBetween(0, 1.0, true) }, + randomDoubleBetween(1.1, 10.0, true) + ); + } + + /** + * @throws IOException when serialization/deserialization has issues. + */ + public void testDeserializeRCFResultResponse() throws IOException { + setUpRCFResultResponse(); + + rcfResultResponse.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + RCFResultResponse readResponse = new RCFResultResponse(streamInput); + assertArrayEquals(readResponse.getAttribution(), rcfResultResponse.getAttribution(), 0.001); + assertThat(readResponse.getConfidence(), equalTo(rcfResultResponse.getConfidence())); + assertThat(readResponse.getForestSize(), equalTo(rcfResultResponse.getForestSize())); + assertThat(readResponse.getTotalUpdates(), equalTo(rcfResultResponse.getTotalUpdates())); + assertThat(readResponse.getRCFScore(), equalTo(rcfResultResponse.getRCFScore())); + } +} diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java-e b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java-e new file mode 100644 index 000000000..e2904c319 --- /dev/null +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java-e @@ -0,0 +1,784 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.action.admin.indices.mapping.get; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Clock; +import java.util.Arrays; +import java.util.Locale; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.mockito.ArgumentCaptor; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionResponse; +import org.opensearch.action.ActionType; +import org.opensearch.action.get.GetAction; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchAction; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.rest.handler.IndexAnomalyDetectorActionHandler; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.rest.RestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.transport.TransportService; + +/** + * + * we need to put the test in the same package of GetFieldMappingsResponse + * (org.opensearch.action.admin.indices.mapping.get) since its constructor is + * package private + * + */ +public class IndexAnomalyDetectorActionHandlerTests extends AbstractTimeSeriesTest { + static ThreadPool threadPool; + private ThreadContext threadContext; + private String TEXT_FIELD_TYPE = "text"; + private IndexAnomalyDetectorActionHandler handler; + private ClusterService clusterService; + private NodeClient clientMock; + private SecurityClientUtil clientUtil; + private TransportService transportService; + private ActionListener channel; + private ADIndexManagement anomalyDetectionIndices; + private String detectorId; + private Long seqNo; + private Long primaryTerm; + private AnomalyDetector detector; + private WriteRequest.RefreshPolicy refreshPolicy; + private TimeValue requestTimeout; + private Integer maxSingleEntityAnomalyDetectors; + private Integer 
maxMultiEntityAnomalyDetectors; + private Integer maxAnomalyFeatures; + private Settings settings; + private RestRequest.Method method; + private ADTaskManager adTaskManager; + private SearchFeatureDao searchFeatureDao; + private Clock clock; + + @BeforeClass + public static void beforeClass() { + threadPool = new TestThreadPool("IndexAnomalyDetectorJobActionHandlerTests"); + } + + @AfterClass + public static void afterClass() { + ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + threadPool = null; + } + + @SuppressWarnings("unchecked") + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + settings = Settings.EMPTY; + clusterService = mock(ClusterService.class); + clientMock = spy(new NodeClient(settings, threadPool)); + clock = mock(Clock.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); + transportService = mock(TransportService.class); + + channel = mock(ActionListener.class); + + anomalyDetectionIndices = mock(ADIndexManagement.class); + when(anomalyDetectionIndices.doesConfigIndexExist()).thenReturn(true); + + detectorId = "123"; + seqNo = 0L; + primaryTerm = 0L; + + WriteRequest.RefreshPolicy refreshPolicy = WriteRequest.RefreshPolicy.IMMEDIATE; + + String field = "a"; + detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); + + requestTimeout = new TimeValue(1000L); + + maxSingleEntityAnomalyDetectors = 1000; + + maxMultiEntityAnomalyDetectors = 10; + + maxAnomalyFeatures = 5; + + method = RestRequest.Method.POST; + + adTaskManager = mock(ADTaskManager.class); + + searchFeatureDao = mock(SearchFeatureDao.class); + + handler = new IndexAnomalyDetectorActionHandler( + clusterService, + clientMock, + clientUtil, + transportService, + channel, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + detector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + method, + xContentRegistry(), + null, + adTaskManager, + searchFeatureDao, + Settings.EMPTY + ); + } + + // we support upto 2 category fields now + public void testThreeCategoricalFields() throws IOException { + expectThrows( + ValidationException.class, + () -> TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("a", "b", "c")) + ); + } + + @SuppressWarnings("unchecked") + public void testMoreThanTenThousandSingleEntityDetectors() throws IOException { + SearchResponse mockResponse = mock(SearchResponse.class); + int totalHits = 1001; + when(mockResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); + SearchResponse detectorResponse = mock(SearchResponse.class); + when(detectorResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); + SearchResponse userIndexResponse = mock(SearchResponse.class); + int userIndexHits = 0; + when(userIndexResponse.getHits()).thenReturn(TestHelpers.createSearchHits(userIndexHits)); + + // extend NodeClient since its execute method is final and mockito does not allow to mock final methods + // we can also use spy to overstep the final methods + NodeClient client = getCustomNodeClient(detectorResponse, userIndexResponse, detector, threadPool); + NodeClient clientSpy = spy(client); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); + + handler = new IndexAnomalyDetectorActionHandler( + clusterService, + 
clientSpy, + clientUtil, + transportService, + channel, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + // no categorical feature + TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null, true), + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + method, + xContentRegistry(), + null, + adTaskManager, + searchFeatureDao, + Settings.EMPTY + ); + + handler.start(); + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + verify(clientMock, never()).execute(eq(GetMappingsAction.INSTANCE), any(), any()); + verify(channel).onFailure(response.capture()); + Exception value = response.getValue(); + assertTrue(value instanceof IllegalArgumentException); + String errorMsg = String + .format( + Locale.ROOT, + IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG, + maxSingleEntityAnomalyDetectors + ); + assertTrue(value.getMessage().contains(errorMsg)); + } + + @SuppressWarnings("unchecked") + public void testTextField() throws IOException { + String field = "a"; + AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); + + SearchResponse detectorResponse = mock(SearchResponse.class); + int totalHits = 9; + when(detectorResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); + + // extend NodeClient since its execute method is final and mockito does not allow to mock final methods + // we can also use spy to overstep the final methods + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + try { + if (action.equals(SearchAction.INSTANCE)) { + listener.onResponse((Response) detectorResponse); + } else { + // passes first get field mapping call where timestamp has to be of type date + // fails on second call where categorical field is checked to be type keyword or IP + // we need to put the test in the same package of GetFieldMappingsResponse since its constructor is package private + GetFieldMappingsResponse response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(detector.getIndices().get(0), field, "date") + ); + listener.onResponse((Response) response); + } + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } + }; + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + + handler = new IndexAnomalyDetectorActionHandler( + clusterService, + client, + clientUtil, + transportService, + channel, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + detector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + method, + xContentRegistry(), + null, + adTaskManager, + searchFeatureDao, + Settings.EMPTY + ); + + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + + handler.start(); + + verify(channel).onFailure(response.capture()); + Exception value = response.getValue(); + assertTrue(value instanceof Exception); + assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.CATEGORICAL_FIELD_TYPE_ERR_MSG)); + } + + @SuppressWarnings("unchecked") + private void testValidTypeTemplate(String filedTypeName) throws IOException { + String field = "a"; + AnomalyDetector detector = 
TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); + + SearchResponse detectorResponse = mock(SearchResponse.class); + int totalHits = 9; + when(detectorResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); + + SearchResponse userIndexResponse = mock(SearchResponse.class); + int userIndexHits = 0; + when(userIndexResponse.getHits()).thenReturn(TestHelpers.createSearchHits(userIndexHits)); + AtomicBoolean isPreCategoryMappingQuery = new AtomicBoolean(true); + // extend NodeClient since its execute method is final and mockito does not allow to mock final methods + // we can also use spy to overstep the final methods + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + try { + if (action.equals(SearchAction.INSTANCE)) { + assertTrue(request instanceof SearchRequest); + SearchRequest searchRequest = (SearchRequest) request; + if (searchRequest.indices()[0].equals(CommonName.CONFIG_INDEX)) { + listener.onResponse((Response) detectorResponse); + } else { + listener.onResponse((Response) userIndexResponse); + } + } else if (isPreCategoryMappingQuery.get()) { + isPreCategoryMappingQuery.set(false); + GetFieldMappingsResponse response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(detector.getIndices().get(0), field, "date") + ); + listener.onResponse((Response) response); + } else { + GetFieldMappingsResponse response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(detector.getIndices().get(0), field, filedTypeName) + ); + listener.onResponse((Response) response); + } + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } + }; + + NodeClient clientSpy = spy(client); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + + handler = new IndexAnomalyDetectorActionHandler( + clusterService, + clientSpy, + clientUtil, + transportService, + channel, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + detector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + method, + xContentRegistry(), + null, + adTaskManager, + searchFeatureDao, + Settings.EMPTY + ); + + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + + handler.start(); + + verify(clientSpy, times(2)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + verify(channel).onFailure(response.capture()); + Exception value = response.getValue(); + assertTrue(value instanceof IllegalArgumentException); + assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.NO_DOCS_IN_USER_INDEX_MSG)); + } + + public void testIpField() throws IOException { + testValidTypeTemplate(CommonName.IP_TYPE); + } + + public void testKeywordField() throws IOException { + testValidTypeTemplate(CommonName.KEYWORD_TYPE); + } + + @SuppressWarnings("unchecked") + private void testUpdateTemplate(String fieldTypeName) throws IOException { + String field = "a"; + AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); + + SearchResponse detectorResponse = mock(SearchResponse.class); + int totalHits = 9; + when(detectorResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); + + GetResponse getDetectorResponse = 
TestHelpers.createGetResponse(detector, detector.getId(), CommonName.CONFIG_INDEX); + + SearchResponse userIndexResponse = mock(SearchResponse.class); + int userIndexHits = 0; + when(userIndexResponse.getHits()).thenReturn(TestHelpers.createSearchHits(userIndexHits)); + + // extend NodeClient since its execute method is final and mockito does not allow to mock final methods + // we can also use spy to overstep the final methods + NodeClient client = new NodeClient(Settings.EMPTY, threadPool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + try { + if (action.equals(SearchAction.INSTANCE)) { + assertTrue(request instanceof SearchRequest); + SearchRequest searchRequest = (SearchRequest) request; + if (searchRequest.indices()[0].equals(CommonName.CONFIG_INDEX)) { + listener.onResponse((Response) detectorResponse); + } else { + listener.onResponse((Response) userIndexResponse); + } + } else if (action.equals(GetAction.INSTANCE)) { + assertTrue(request instanceof GetRequest); + listener.onResponse((Response) getDetectorResponse); + } else { + GetFieldMappingsResponse response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(detector.getIndices().get(0), field, fieldTypeName) + ); + listener.onResponse((Response) response); + } + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } + }; + + NodeClient clientSpy = spy(client); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + ClusterName clusterName = new ClusterName("test"); + ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build(); + when(clusterService.state()).thenReturn(clusterState); + + handler = new IndexAnomalyDetectorActionHandler( + clusterService, + clientSpy, + clientUtil, + transportService, + channel, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + detector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + RestRequest.Method.PUT, + xContentRegistry(), + null, + adTaskManager, + searchFeatureDao, + Settings.EMPTY + ); + + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + + handler.start(); + + verify(clientSpy, times(1)).execute(eq(GetFieldMappingsAction.INSTANCE), any(), any()); + verify(channel).onFailure(response.capture()); + Exception value = response.getValue(); + if (fieldTypeName.equals(CommonName.IP_TYPE) || fieldTypeName.equals(CommonName.KEYWORD_TYPE)) { + assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.NO_DOCS_IN_USER_INDEX_MSG)); + } else { + assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.CATEGORICAL_FIELD_TYPE_ERR_MSG)); + } + } + + @Ignore + public void testUpdateIpField() throws IOException { + testUpdateTemplate(CommonName.IP_TYPE); + } + + @Ignore + public void testUpdateKeywordField() throws IOException { + testUpdateTemplate(CommonName.KEYWORD_TYPE); + } + + @Ignore + public void testUpdateTextField() throws IOException { + testUpdateTemplate(TEXT_FIELD_TYPE); + } + + public static NodeClient getCustomNodeClient( + SearchResponse detectorResponse, + SearchResponse userIndexResponse, + AnomalyDetector detector, + ThreadPool pool + ) { + return new NodeClient(Settings.EMPTY, pool) { + @Override + public void doExecute( + ActionType action, + Request request, + ActionListener listener + 
) { + try { + if (action.equals(SearchAction.INSTANCE)) { + assertTrue(request instanceof SearchRequest); + SearchRequest searchRequest = (SearchRequest) request; + if (searchRequest.indices()[0].equals(CommonName.CONFIG_INDEX)) { + listener.onResponse((Response) detectorResponse); + } else { + listener.onResponse((Response) userIndexResponse); + } + } else { + GetFieldMappingsResponse response = new GetFieldMappingsResponse( + TestHelpers.createFieldMappings(detector.getIndices().get(0), "timestamp", "date") + ); + listener.onResponse((Response) response); + } + } catch (IOException e) { + logger.error("Create field mapping threw an exception", e); + } + } + }; + } + + @SuppressWarnings("unchecked") + public void testMoreThanTenMultiEntityDetectors() throws IOException { + String field = "a"; + AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); + SearchResponse detectorResponse = mock(SearchResponse.class); + int totalHits = 11; + when(detectorResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); + + SearchResponse userIndexResponse = mock(SearchResponse.class); + int userIndexHits = 0; + when(userIndexResponse.getHits()).thenReturn(TestHelpers.createSearchHits(userIndexHits)); + // extend NodeClient since its execute method is final and mockito does not allow to mock final methods + // we can also use spy to overstep the final methods + NodeClient client = getCustomNodeClient(detectorResponse, userIndexResponse, detector, threadPool); + NodeClient clientSpy = spy(client); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); + + handler = new IndexAnomalyDetectorActionHandler( + clusterService, + clientSpy, + clientUtil, + transportService, + channel, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + detector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + method, + xContentRegistry(), + null, + adTaskManager, + searchFeatureDao, + Settings.EMPTY + ); + + handler.start(); + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + verify(clientSpy, times(1)).search(any(SearchRequest.class), any()); + verify(channel).onFailure(response.capture()); + Exception value = response.getValue(); + assertTrue(value instanceof IllegalArgumentException); + String errorMsg = String + .format( + Locale.ROOT, + IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG, + maxMultiEntityAnomalyDetectors + ); + assertTrue(value.getMessage().contains(errorMsg)); + } + + @Ignore + @SuppressWarnings("unchecked") + public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOException { + int totalHits = 10; + AnomalyDetector existingDetector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, null); + GetResponse getDetectorResponse = TestHelpers + .createGetResponse(existingDetector, existingDetector.getId(), CommonName.CONFIG_INDEX); + + SearchResponse searchResponse = mock(SearchResponse.class); + when(searchResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. 
Its content is %s", args.length, Arrays.toString(args)), + args.length == 2 + ); + + assertTrue(args[0] instanceof SearchRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[1]; + + listener.onResponse(searchResponse); + + return null; + }).when(clientMock).search(any(SearchRequest.class), any()); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + args.length == 2 + ); + + assertTrue(args[0] instanceof GetRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[1]; + + listener.onResponse(getDetectorResponse); + + return null; + }).when(clientMock).get(any(GetRequest.class), any()); + + ClusterName clusterName = new ClusterName("test"); + ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build(); + when(clusterService.state()).thenReturn(clusterState); + + handler = new IndexAnomalyDetectorActionHandler( + clusterService, + clientMock, + clientUtil, + transportService, + channel, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + detector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + RestRequest.Method.PUT, + xContentRegistry(), + null, + adTaskManager, + searchFeatureDao, + Settings.EMPTY + ); + + handler.start(); + + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + verify(clientMock, times(1)).search(any(SearchRequest.class), any()); + verify(clientMock, times(1)).get(any(GetRequest.class), any()); + verify(channel).onFailure(response.capture()); + Exception value = response.getValue(); + assertTrue(value instanceof IllegalArgumentException); + assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG)); + } + + @Ignore + @SuppressWarnings("unchecked") + public void testTenMultiEntityDetectorsUpdateExistingMultiEntityAd() throws IOException { + int totalHits = 10; + AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("a")); + GetResponse getDetectorResponse = TestHelpers.createGetResponse(detector, detector.getId(), CommonName.CONFIG_INDEX); + + SearchResponse searchResponse = mock(SearchResponse.class); + when(searchResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + args.length == 2 + ); + + assertTrue(args[0] instanceof SearchRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[1]; + + listener.onResponse(searchResponse); + + return null; + }).when(clientMock).search(any(SearchRequest.class), any()); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. 
Its content is %s", args.length, Arrays.toString(args)), + args.length == 2 + ); + + assertTrue(args[0] instanceof GetRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[1]; + + listener.onResponse(getDetectorResponse); + + return null; + }).when(clientMock).get(any(GetRequest.class), any()); + + ClusterName clusterName = new ClusterName("test"); + ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build(); + when(clusterService.state()).thenReturn(clusterState); + + handler = new IndexAnomalyDetectorActionHandler( + clusterService, + clientMock, + clientUtil, + transportService, + channel, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + refreshPolicy, + detector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + RestRequest.Method.PUT, + xContentRegistry(), + null, + adTaskManager, + searchFeatureDao, + Settings.EMPTY + ); + + handler.start(); + + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + verify(clientMock, times(0)).search(any(SearchRequest.class), any()); + verify(clientMock, times(1)).get(any(GetRequest.class), any()); + verify(channel).onFailure(response.capture()); + Exception value = response.getValue(); + // make sure execution passes all necessary checks + assertTrue(value instanceof IllegalStateException); + assertTrue(value.getMessage().contains("NodeClient has not been initialized")); + } +} diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java-e b/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java-e new file mode 100644 index 000000000..2869943b6 --- /dev/null +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java-e @@ -0,0 +1,236 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.action.admin.indices.mapping.get; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Clock; +import java.util.Arrays; +import java.util.Locale; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.rest.handler.AbstractAnomalyDetectorActionHandler; +import org.opensearch.ad.rest.handler.IndexAnomalyDetectorActionHandler; +import org.opensearch.ad.rest.handler.ValidateAnomalyDetectorActionHandler; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.rest.RestRequest; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableList; + +public class ValidateAnomalyDetectorActionHandlerTests extends AbstractTimeSeriesTest { + + protected AbstractAnomalyDetectorActionHandler handler; + protected ClusterService clusterService; + protected ActionListener channel; + protected TransportService transportService; + protected ADIndexManagement anomalyDetectionIndices; + protected String detectorId; + protected Long seqNo; + protected Long primaryTerm; + protected AnomalyDetector detector; + protected WriteRequest.RefreshPolicy refreshPolicy; + protected TimeValue requestTimeout; + protected Integer maxSingleEntityAnomalyDetectors; + protected Integer maxMultiEntityAnomalyDetectors; + protected Integer maxAnomalyFeatures; + protected Settings settings; + protected RestRequest.Method method; + protected ADTaskManager adTaskManager; + protected SearchFeatureDao searchFeatureDao; + protected Clock clock; + + @Mock + private Client clientMock; + @Mock + protected ThreadPool threadPool; + protected ThreadContext threadContext; + + @SuppressWarnings("unchecked") + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.initMocks(this); + + settings = Settings.EMPTY; + clusterService = mock(ClusterService.class); + channel = mock(ActionListener.class); + transportService = mock(TransportService.class); + + anomalyDetectionIndices = mock(ADIndexManagement.class); + when(anomalyDetectionIndices.doesConfigIndexExist()).thenReturn(true); + + 
detectorId = "123"; + seqNo = 0L; + primaryTerm = 0L; + clock = mock(Clock.class); + + refreshPolicy = WriteRequest.RefreshPolicy.IMMEDIATE; + + String field = "a"; + detector = TestHelpers + .randomAnomalyDetectorUsingCategoryFields(detectorId, "timestamp", ImmutableList.of("test-index"), Arrays.asList(field)); + + requestTimeout = new TimeValue(1000L); + maxSingleEntityAnomalyDetectors = 1000; + maxMultiEntityAnomalyDetectors = 10; + maxAnomalyFeatures = 5; + method = RestRequest.Method.POST; + adTaskManager = mock(ADTaskManager.class); + searchFeatureDao = mock(SearchFeatureDao.class); + + threadContext = new ThreadContext(settings); + Mockito.doReturn(threadPool).when(clientMock).threadPool(); + Mockito.doReturn(threadContext).when(threadPool).getThreadContext(); + } + + @SuppressWarnings("unchecked") + public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOException { + SearchResponse mockResponse = mock(SearchResponse.class); + int totalHits = maxSingleEntityAnomalyDetectors + 1; + when(mockResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); + SearchResponse detectorResponse = mock(SearchResponse.class); + when(detectorResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); + SearchResponse userIndexResponse = mock(SearchResponse.class); + int userIndexHits = 0; + when(userIndexResponse.getHits()).thenReturn(TestHelpers.createSearchHits(userIndexHits)); + AnomalyDetector singleEntityDetector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null, true); + + // extend NodeClient since its execute method is final and mockito does not allow to mock final methods + // we can also use spy to overstep the final methods + NodeClient client = IndexAnomalyDetectorActionHandlerTests + .getCustomNodeClient(detectorResponse, userIndexResponse, singleEntityDetector, threadPool); + + NodeClient clientSpy = spy(client); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + SecurityClientUtil clientUtil = new SecurityClientUtil(nodeStateManager, settings); + + handler = new ValidateAnomalyDetectorActionHandler( + clusterService, + clientSpy, + clientUtil, + channel, + anomalyDetectionIndices, + singleEntityDetector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + method, + xContentRegistry(), + null, + searchFeatureDao, + ValidationAspect.DETECTOR.getName(), + clock, + settings + ); + handler.start(); + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + verify(clientSpy, never()).execute(eq(GetMappingsAction.INSTANCE), any(), any()); + verify(channel).onFailure(response.capture()); + Exception value = response.getValue(); + assertTrue(value instanceof ValidationException); + String errorMsg = String + .format( + Locale.ROOT, + IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG, + maxSingleEntityAnomalyDetectors + ); + assertTrue(value.getMessage().contains(errorMsg)); + } + + @SuppressWarnings("unchecked") + public void testValidateMoreThanTenMultiEntityDetectorsLimit() throws IOException { + String field = "a"; + AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); + + SearchResponse detectorResponse = mock(SearchResponse.class); + int totalHits = maxMultiEntityAnomalyDetectors + 1; + when(detectorResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); + + SearchResponse userIndexResponse = mock(SearchResponse.class); 
+ int userIndexHits = 0; + when(userIndexResponse.getHits()).thenReturn(TestHelpers.createSearchHits(userIndexHits)); + // extend NodeClient since its execute method is final and mockito does not allow to mock final methods + // we can also use spy to overstep the final methods + NodeClient client = IndexAnomalyDetectorActionHandlerTests + .getCustomNodeClient(detectorResponse, userIndexResponse, detector, threadPool); + NodeClient clientSpy = spy(client); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + SecurityClientUtil clientUtil = new SecurityClientUtil(nodeStateManager, settings); + + handler = new ValidateAnomalyDetectorActionHandler( + clusterService, + clientSpy, + clientUtil, + channel, + anomalyDetectionIndices, + detector, + requestTimeout, + maxSingleEntityAnomalyDetectors, + maxMultiEntityAnomalyDetectors, + maxAnomalyFeatures, + method, + xContentRegistry(), + null, + searchFeatureDao, + "", + clock, + Settings.EMPTY + ); + handler.start(); + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + verify(clientSpy, never()).execute(eq(GetMappingsAction.INSTANCE), any(), any()); + verify(channel).onFailure(response.capture()); + Exception value = response.getValue(); + assertTrue(value instanceof ValidationException); + String errorMsg = String + .format( + Locale.ROOT, + IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG, + maxMultiEntityAnomalyDetectors + ); + assertTrue(value.getMessage().contains(errorMsg)); + } +} diff --git a/src/test/java/org/opensearch/ad/ADIntegTestCase.java b/src/test/java/org/opensearch/ad/ADIntegTestCase.java index a1b8daacf..992f137f5 100644 --- a/src/test/java/org/opensearch/ad/ADIntegTestCase.java +++ b/src/test/java/org/opensearch/ad/ADIntegTestCase.java @@ -50,17 +50,18 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.plugins.Plugin; -import org.opensearch.rest.RestStatus; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.test.transport.MockTransportService; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Feature; @@ -81,11 +82,11 @@ public abstract class ADIntegTestCase extends OpenSearchIntegTestCase { @Override protected Collection> nodePlugins() { - return Collections.singletonList(AnomalyDetectorPlugin.class); + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } protected Collection> transportClientPlugins() { - return Collections.singletonList(AnomalyDetectorPlugin.class); + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } @Override diff --git a/src/test/java/org/opensearch/ad/ADIntegTestCase.java-e b/src/test/java/org/opensearch/ad/ADIntegTestCase.java-e new file mode 100644 index 000000000..f4c45fc7f --- /dev/null +++ b/src/test/java/org/opensearch/ad/ADIntegTestCase.java-e @@ -0,0 +1,333 @@ +/* + * 
SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.core.Logger; +import org.junit.Before; +import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest; +import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsResponse; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import org.opensearch.action.bulk.BulkRequestBuilder; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.mock.plugin.MockReindexPlugin; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.plugins.Plugin; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.test.transport.MockTransportService; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.util.RestHandlerUtils; + +import com.google.common.collect.ImmutableMap; + +public abstract class ADIntegTestCase extends OpenSearchIntegTestCase { + protected static final Logger LOG = (Logger) LogManager.getLogger(ADIntegTestCase.class); + + private long timeout = 5_000; + protected String timeField = "timestamp"; + protected String categoryField = "type"; + protected String ipField = "ip"; + protected String valueField = "value"; + protected String nameField = "test"; + protected int DEFAULT_TEST_DATA_DOCS = 3000; + + @Override + protected 
Collection> nodePlugins() { + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); + } + + protected Collection> transportClientPlugins() { + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); + } + + @Override + protected Collection> getMockPlugins() { + final ArrayList> plugins = new ArrayList<>(); + plugins.add(MockReindexPlugin.class); + plugins.addAll(super.getMockPlugins()); + plugins.remove(MockTransportService.TestPlugin.class); + return Collections.unmodifiableList(plugins); + } + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + } + + public void createDetectors(List detectors, boolean createIndexFirst) throws IOException { + if (createIndexFirst) { + createIndex(CommonName.CONFIG_INDEX, ADIndexManagement.getConfigMappings()); + } + + for (AnomalyDetector detector : detectors) { + indexDoc(CommonName.CONFIG_INDEX, detector.toXContent(jsonBuilder(), XCONTENT_WITH_TYPE)); + } + } + + public String createDetector(AnomalyDetector detector) throws IOException { + return indexDoc(CommonName.CONFIG_INDEX, detector.toXContent(jsonBuilder(), XCONTENT_WITH_TYPE)); + } + + public String createADResult(AnomalyResult adResult) throws IOException { + return indexDoc(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, adResult.toXContent(jsonBuilder(), XCONTENT_WITH_TYPE)); + } + + public String createADTask(ADTask adTask) throws IOException { + if (adTask.getTaskId() != null) { + return indexDoc(ADCommonName.DETECTION_STATE_INDEX, adTask.getTaskId(), adTask.toXContent(jsonBuilder(), XCONTENT_WITH_TYPE)); + } + return indexDoc(ADCommonName.DETECTION_STATE_INDEX, adTask.toXContent(jsonBuilder(), XCONTENT_WITH_TYPE)); + } + + public void createDetectorIndex() throws IOException { + createIndex(CommonName.CONFIG_INDEX, ADIndexManagement.getConfigMappings()); + } + + public void createADResultIndex() throws IOException { + createIndex(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, ADIndexManagement.getResultMappings()); + } + + public void createCustomADResultIndex(String indexName) throws IOException { + createIndex(indexName, ADIndexManagement.getResultMappings()); + } + + public void createDetectionStateIndex() throws IOException { + createIndex(ADCommonName.DETECTION_STATE_INDEX, ADIndexManagement.getStateMappings()); + } + + public void createTestDataIndex(String indexName) { + String mappings = "{\"properties\":{\"" + + timeField + + "\":{\"type\":\"date\",\"format\":\"strict_date_time||epoch_millis\"}," + + "\"value\":{\"type\":\"double\"}, \"" + + categoryField + + "\":{\"type\":\"keyword\"},\"" + + ipField + + "\":{\"type\":\"ip\"}," + + "\"is_error\":{\"type\":\"boolean\"}, \"message\":{\"type\":\"text\"}}}"; + createIndex(indexName, mappings); + } + + public void createIndex(String indexName, String mappings) { + CreateIndexResponse createIndexResponse = TestHelpers.createIndex(admin(), indexName, mappings); + assertEquals(true, createIndexResponse.isAcknowledged()); + } + + public AcknowledgedResponse deleteDetectorIndex() { + return deleteIndex(CommonName.CONFIG_INDEX); + } + + public AcknowledgedResponse deleteIndex(String indexName) { + DeleteIndexRequest deleteIndexRequest = new DeleteIndexRequest(indexName); + return admin().indices().delete(deleteIndexRequest).actionGet(timeout); + } + + public void deleteIndexIfExists(String indexName) { + if (indexExists(indexName)) { + DeleteIndexRequest deleteIndexRequest = new DeleteIndexRequest(indexName); + admin().indices().delete(deleteIndexRequest).actionGet(timeout); + } + } + + public String 
indexDoc(String indexName, XContentBuilder source) { + IndexRequest indexRequest = new IndexRequest(indexName).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).source(source); + IndexResponse indexResponse = client().index(indexRequest).actionGet(timeout); + assertEquals(RestStatus.CREATED, indexResponse.status()); + return indexResponse.getId(); + } + + public String indexDoc(String indexName, String id, XContentBuilder source) { + IndexRequest indexRequest = new IndexRequest(indexName) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .source(source) + .id(id); + IndexResponse indexResponse = client().index(indexRequest).actionGet(timeout); + assertEquals(RestStatus.CREATED, indexResponse.status()); + return indexResponse.getId(); + } + + public String indexDoc(String indexName, Map source) { + IndexRequest indexRequest = new IndexRequest(indexName).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).source(source); + IndexResponse indexResponse = client().index(indexRequest).actionGet(timeout); + assertEquals(RestStatus.CREATED, indexResponse.status()); + return indexResponse.getId(); + } + + public BulkResponse bulkIndexObjects(String indexName, List objects) { + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + objects.forEach(obj -> { + try (XContentBuilder builder = jsonBuilder()) { + IndexRequest indexRequest = new IndexRequest(indexName) + .source(obj.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)); + bulkRequestBuilder.add(indexRequest); + } catch (Exception e) { + String error = "Failed to prepare request to bulk index docs"; + LOG.error(error, e); + throw new TimeSeriesException(error); + } + }); + return client().bulk(bulkRequestBuilder.request()).actionGet(timeout); + } + + public BulkResponse bulkIndexDocs(String indexName, List> docs, long timeout) { + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + docs.forEach(doc -> bulkRequestBuilder.add(new IndexRequest(indexName).source(doc))); + return client().bulk(bulkRequestBuilder.request().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)).actionGet(timeout); + } + + public GetResponse getDoc(String indexName, String id) { + GetRequest getRequest = new GetRequest(indexName).id(id); + return client().get(getRequest).actionGet(timeout); + } + + public long countDocs(String indexName) { + SearchRequest request = new SearchRequest(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(new MatchAllQueryBuilder()).size(0); + request.indices(indexName).source(searchSourceBuilder); + SearchResponse searchResponse = client().search(request).actionGet(timeout); + return searchResponse.getHits().getTotalHits().value; + } + + public long countDetectorDocs(String detectorId) { + SearchRequest request = new SearchRequest(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(new TermQueryBuilder("detector_id", detectorId)).size(10); + request.indices(ADCommonName.DETECTION_STATE_INDEX).source(searchSourceBuilder); + SearchResponse searchResponse = client().search(request).actionGet(timeout); + return searchResponse.getHits().getTotalHits().value; + } + + public ClusterUpdateSettingsResponse updateTransientSettings(Map settings) { + ClusterUpdateSettingsRequest updateSettingsRequest = new ClusterUpdateSettingsRequest(); + updateSettingsRequest.transientSettings(settings); + return clusterAdmin().updateSettings(updateSettingsRequest).actionGet(timeout); + } + + public Map 
getDataNodes() { + DiscoveryNodes nodes = clusterService().state().getNodes(); + return nodes.getDataNodes(); + } + + public Client getDataNodeClient() { + for (Client client : clients()) { + if (client instanceof NodeClient) { + return client; + } + } + return null; + } + + public DiscoveryNode[] getDataNodesArray() { + DiscoveryNodes nodes = clusterService().state().getNodes(); + Collection nodeCollection = nodes.getDataNodes().values(); + List dataNodes = new ArrayList<>(); + for (DiscoveryNode node : nodeCollection) { + dataNodes.add(node); + } + return dataNodes.toArray(new DiscoveryNode[0]); + } + + public void ingestTestDataValidate(String testIndex, Instant startTime, int detectionIntervalInMinutes, String type) { + ingestTestDataValidate(testIndex, startTime, detectionIntervalInMinutes, type, DEFAULT_TEST_DATA_DOCS); + } + + public void ingestTestDataValidate(String testIndex, Instant startTime, int detectionIntervalInMinutes, String type, int totalDocs) { + createTestDataIndex(testIndex); + List> docs = new ArrayList<>(); + Instant currentInterval = Instant.from(startTime); + + for (int i = 0; i < totalDocs; i++) { + currentInterval = currentInterval.plus(detectionIntervalInMinutes, ChronoUnit.MINUTES); + double value = i % 500 == 0 ? randomDoubleBetween(1000, 2000, true) : randomDoubleBetween(10, 100, true); + docs + .add( + ImmutableMap + .of( + timeField, + currentInterval.toEpochMilli(), + "value", + value, + "type", + type, + "is_error", + randomBoolean(), + "message", + randomAlphaOfLength(5) + ) + ); + } + BulkResponse bulkResponse = bulkIndexDocs(testIndex, docs, 30_000); + assertEquals(RestStatus.OK, bulkResponse.status()); + assertFalse(bulkResponse.hasFailures()); + long count = countDocs(testIndex); + assertEquals(totalDocs, count); + } + + public Feature maxValueFeature() throws IOException { + return maxValueFeature(nameField, valueField, nameField); + } + + public Feature maxValueFeature(String aggregationName, String fieldName, String featureName) throws IOException { + AggregationBuilder aggregationBuilder = TestHelpers + .parseAggregation("{\"" + aggregationName + "\":{\"max\":{\"field\":\"" + fieldName + "\"}}}"); + return new Feature(randomAlphaOfLength(5), featureName, true, aggregationBuilder); + } + + public Feature sumValueFeature(String aggregationName, String fieldName, String featureName) throws IOException { + AggregationBuilder aggregationBuilder = TestHelpers + .parseAggregation("{\"" + aggregationName + "\":{\"value_count\":{\"field\":\"" + fieldName + "\"}}}"); + return new Feature(randomAlphaOfLength(5), featureName, true, aggregationBuilder); + } + +} diff --git a/src/test/java/org/opensearch/ad/ADUnitTestCase.java-e b/src/test/java/org/opensearch/ad/ADUnitTestCase.java-e new file mode 100644 index 000000000..232c5dcdc --- /dev/null +++ b/src/test/java/org/opensearch/ad/ADUnitTestCase.java-e @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad; + +import static org.opensearch.cluster.node.DiscoveryNodeRole.BUILT_IN_ROLES; + +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.test.OpenSearchTestCase; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Sets; + +public class ADUnitTestCase extends OpenSearchTestCase { + + @Captor + protected ArgumentCaptor exceptionCaptor; + + @Override + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.initMocks(this); + } + + /** + * Create cluster setting. + * + * @param settings cluster settings + * @param setting add setting if the code to be tested contains setting update consumer + * @return instance of ClusterSettings + */ + public ClusterSettings clusterSetting(Settings settings, Setting... setting) { + final Set> settingsSet = Stream + .concat(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), Sets.newHashSet(setting).stream()) + .collect(Collectors.toSet()); + ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet); + return clusterSettings; + } + + protected DiscoveryNode createNode(String nodeId) { + return new DiscoveryNode( + nodeId, + new TransportAddress(TransportAddress.META_ADDRESS, 9300), + ImmutableMap.of(), + BUILT_IN_ROLES, + Version.CURRENT + ); + } + + protected DiscoveryNode createNode(String nodeId, String ip, int port, Map attributes) throws UnknownHostException { + return new DiscoveryNode( + nodeId, + new TransportAddress(InetAddress.getByName(ip), port), + attributes, + BUILT_IN_ROLES, + Version.CURRENT + ); + } +} diff --git a/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java-e b/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java-e new file mode 100644 index 000000000..f28de4547 --- /dev/null +++ b/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java-e @@ -0,0 +1,187 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.time.Clock; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; +import java.util.function.Consumer; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.opensearch.Version; +import org.opensearch.action.get.GetResponse; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.DetectorProfileName; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.AnomalyResultTests; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.transport.TransportService; + +public class AbstractProfileRunnerTests extends AbstractTimeSeriesTest { + protected enum DetectorStatus { + INDEX_NOT_EXIST, + NO_DOC, + EXIST + } + + protected enum JobStatus { + INDEX_NOT_EXIT, + DISABLED, + ENABLED + } + + protected enum ErrorResultStatus { + INDEX_NOT_EXIT, + NO_ERROR, + SHINGLE_ERROR, + STOPPED_ERROR, + NULL_POINTER_EXCEPTION + } + + protected AnomalyDetectorProfileRunner runner; + protected Client client; + protected SecurityClientUtil clientUtil; + protected DiscoveryNodeFilterer nodeFilter; + protected AnomalyDetector detector; + protected ClusterService clusterService; + protected TransportService transportService; + protected ADTaskManager adTaskManager; + + protected static Set stateOnly; + protected static Set stateNError; + protected static Set modelProfile; + protected static Set stateInitProgress; + protected static Set totalInitProgress; + protected static Set initProgressErrorProfile; + + protected static String noFullShingleError = "No full shingle in current detection window"; + protected static String stoppedError = + "Stopped detector as job failed consecutively for more than 3 times: Having trouble querying data." 
+ + " Maybe all of your features have been disabled."; + + protected static String clusterName; + protected static DiscoveryNode discoveryNode1; + + protected int requiredSamples; + protected int neededSamples; + + // profile model related + protected String node1; + protected String nodeName1; + + protected String node2; + protected String nodeName2; + protected DiscoveryNode discoveryNode2; + + protected long modelSize; + protected String model1Id; + protected String model0Id; + + protected int shingleSize; + + protected int detectorIntervalMin; + protected GetResponse detectorGetReponse; + protected String messaingExceptionError = "blah"; + + @BeforeClass + public static void setUpOnce() { + stateOnly = new HashSet(); + stateOnly.add(DetectorProfileName.STATE); + stateNError = new HashSet(); + stateNError.add(DetectorProfileName.ERROR); + stateNError.add(DetectorProfileName.STATE); + stateInitProgress = new HashSet(); + stateInitProgress.add(DetectorProfileName.INIT_PROGRESS); + stateInitProgress.add(DetectorProfileName.STATE); + modelProfile = new HashSet( + Arrays + .asList( + DetectorProfileName.SHINGLE_SIZE, + DetectorProfileName.MODELS, + DetectorProfileName.COORDINATING_NODE, + DetectorProfileName.TOTAL_SIZE_IN_BYTES + ) + ); + totalInitProgress = new HashSet( + Arrays.asList(DetectorProfileName.TOTAL_ENTITIES, DetectorProfileName.INIT_PROGRESS) + ); + initProgressErrorProfile = new HashSet( + Arrays.asList(DetectorProfileName.INIT_PROGRESS, DetectorProfileName.ERROR) + ); + clusterName = "test-cluster-name"; + discoveryNode1 = new DiscoveryNode( + "nodeName1", + "node1", + new TransportAddress(TransportAddress.META_ADDRESS, 9300), + emptyMap(), + emptySet(), + Version.CURRENT + ); + setUpThreadPool(AnomalyResultTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @SuppressWarnings("unchecked") + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + Clock clock = mock(Clock.class); + + nodeFilter = mock(DiscoveryNodeFilterer.class); + clusterService = mock(ClusterService.class); + adTaskManager = mock(ADTaskManager.class); + when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("test cluster")).build()); + + requiredSamples = 128; + neededSamples = 5; + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + Consumer> function = (Consumer>) args[2]; + function.accept(Optional.of(TestHelpers.randomAdTask())); + return null; + }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(any(), any(), any(), any(), anyBoolean(), any()); + + detectorIntervalMin = 3; + detectorGetReponse = mock(GetResponse.class); + + } +} diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java b/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java index 751c7992a..074b6ee86 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java @@ -67,16 +67,16 @@ import org.opensearch.ad.util.ClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import 
org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.index.Index; import org.opensearch.index.get.GetResult; -import org.opensearch.index.shard.ShardId; import org.opensearch.jobscheduler.spi.JobExecutionContext; import org.opensearch.jobscheduler.spi.LockModel; import org.opensearch.jobscheduler.spi.ScheduledJobParameter; diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java-e b/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java-e new file mode 100644 index 000000000..4fed79fc0 --- /dev/null +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java-e @@ -0,0 +1,768 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.NUM_MIN_SAMPLES; +import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Locale; +import java.util.Optional; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ThreadFactory; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.AnomalyResultAction; +import org.opensearch.ad.transport.AnomalyResultResponse; +import org.opensearch.ad.transport.handler.AnomalyIndexHandler; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import 
org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.index.Index; +import org.opensearch.index.get.GetResult; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.jobscheduler.spi.schedule.Schedule; +import org.opensearch.jobscheduler.spi.utils.LockService; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +import com.google.common.collect.ImmutableList; + +public class AnomalyDetectorJobRunnerTests extends AbstractTimeSeriesTest { + + @Mock + private Client client; + + @Mock + private ClientUtil clientUtil; + + @Mock + private ClusterService clusterService; + + private LockService lockService; + + @Mock + private AnomalyDetectorJob jobParameter; + + @Mock + private JobExecutionContext context; + + private AnomalyDetectorJobRunner runner = AnomalyDetectorJobRunner.getJobRunnerInstance(); + + @Mock + private ThreadPool mockedThreadPool; + + private ExecutorService executorService; + + @Mock + private Iterator backoff; + + @Mock + private AnomalyIndexHandler anomalyResultHandler; + + @Mock + private ADTaskManager adTaskManager; + + private ExecuteADResultResponseRecorder recorder; + + @Mock + private DiscoveryNodeFilterer nodeFilter; + + private AnomalyDetector detector; + + @Mock + private ADTaskCacheManager adTaskCacheManager; + + @Mock + private NodeStateManager nodeStateManager; + + private ADIndexManagement anomalyDetectionIndices; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyDetectorJobRunnerTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @SuppressWarnings("unchecked") + @Before + public void setup() throws Exception { + super.setUp(); + super.setUpLog4jForJUnit(AnomalyDetectorJobRunner.class); + MockitoAnnotations.initMocks(this); + ThreadFactory threadFactory = OpenSearchExecutors.daemonThreadFactory(OpenSearchExecutors.threadName("node1", "test-ad")); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + executorService = OpenSearchExecutors.newFixed("test-ad", 4, 100, threadFactory, threadContext); + Mockito.doReturn(executorService).when(mockedThreadPool).executor(anyString()); + Mockito.doReturn(mockedThreadPool).when(client).threadPool(); + Mockito.doReturn(threadContext).when(mockedThreadPool).getThreadContext(); + runner.setThreadPool(mockedThreadPool); + runner.setClient(client); + runner.setAdTaskManager(adTaskManager); + + Settings settings = Settings + .builder() + 
.put("plugins.anomaly_detection.max_retry_for_backoff", 2) + .put("plugins.anomaly_detection.backoff_initial_delay", TimeValue.timeValueMillis(1)) + .put("plugins.anomaly_detection.max_retry_for_end_run_exception", 3) + .build(); + setUpJobParameter(); + + runner.setSettings(settings); + + anomalyDetectionIndices = mock(ADIndexManagement.class); + + runner.setAnomalyDetectionIndices(anomalyDetectionIndices); + + lockService = new LockService(client, clusterService); + doReturn(lockService).when(context).getLockService(); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + GetRequest request = (GetRequest) args[0]; + ActionListener listener = (ActionListener) args[1]; + + if (request.index().equals(CommonName.JOB_INDEX)) { + AnomalyDetectorJob job = TestHelpers.randomAnomalyDetectorJob(true); + listener.onResponse(TestHelpers.createGetResponse(job, randomAlphaOfLength(5), CommonName.JOB_INDEX)); + } + return null; + }).when(client).get(any(), any()); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + args.length >= 2 + ); + + IndexRequest request = null; + ActionListener listener = null; + if (args[0] instanceof IndexRequest) { + request = (IndexRequest) args[0]; + } + if (args[1] instanceof ActionListener) { + listener = (ActionListener) args[1]; + } + + assertTrue(request != null && listener != null); + ShardId shardId = new ShardId(new Index(CommonName.JOB_INDEX, randomAlphaOfLength(10)), 0); + listener.onResponse(new IndexResponse(shardId, request.id(), 1, 1, 1, true)); + + return null; + }).when(client).index(any(), any()); + + when(adTaskCacheManager.hasQueriedResultIndex(anyString())).thenReturn(false); + + detector = TestHelpers.randomAnomalyDetectorWithEmptyFeature(); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + runner.setNodeStateManager(nodeStateManager); + + recorder = new ExecuteADResultResponseRecorder( + anomalyDetectionIndices, + anomalyResultHandler, + adTaskManager, + nodeFilter, + threadPool, + client, + nodeStateManager, + adTaskCacheManager, + 32 + ); + runner.setExecuteADResultResponseRecorder(recorder); + } + + @Rule + public ExpectedException expectedEx = ExpectedException.none(); + + @Override + public void tearDown() throws Exception { + super.tearDown(); + super.tearDownLog4jForJUnit(); + executorService.shutdown(); + } + + @Test + public void testRunJobWithWrongParameterType() { + expectedEx.expect(IllegalArgumentException.class); + expectedEx.expectMessage("Job parameter is not instance of AnomalyDetectorJob, type: "); + + ScheduledJobParameter parameter = mock(ScheduledJobParameter.class); + when(jobParameter.getLockDurationSeconds()).thenReturn(null); + runner.runJob(parameter, context); + } + + @Test + public void testRunJobWithNullLockDuration() throws InterruptedException { + when(jobParameter.getLockDurationSeconds()).thenReturn(null); + when(jobParameter.getSchedule()).thenReturn(new IntervalSchedule(Instant.now(), 1, ChronoUnit.MINUTES)); + runner.runJob(jobParameter, context); + Thread.sleep(2000); + assertTrue(testAppender.containsMessage("Can't get lock for AD job")); + } + + @Test + public void testRunJobWithLockDuration() throws InterruptedException { + 
when(jobParameter.getLockDurationSeconds()).thenReturn(100L); + when(jobParameter.getSchedule()).thenReturn(new IntervalSchedule(Instant.now(), 1, ChronoUnit.MINUTES)); + runner.runJob(jobParameter, context); + Thread.sleep(1000); + assertFalse(testAppender.containsMessage("Can't get lock for AD job")); + verify(context, times(1)).getLockService(); + } + + @Test + public void testRunAdJobWithNullLock() { + LockModel lock = null; + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); + verify(client, never()).execute(any(), any(), any()); + } + + @Test + public void testRunAdJobWithLock() { + LockModel lock = new LockModel("indexName", "jobId", Instant.now(), 10, false); + + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); + verify(client, times(1)).execute(any(), any(), any()); + } + + @Test + public void testRunAdJobWithExecuteException() { + LockModel lock = new LockModel("indexName", "jobId", Instant.now(), 10, false); + + doThrow(RuntimeException.class).when(client).execute(any(), any(), any()); + + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); + verify(client, times(1)).execute(any(), any(), any()); + assertTrue(testAppender.containsMessage("Failed to execute AD job")); + } + + @Test + public void testRunAdJobWithEndRunExceptionNow() { + LockModel lock = new LockModel("indexName", "jobId", Instant.now(), 10, false); + Exception exception = new EndRunException(jobParameter.getName(), randomAlphaOfLength(5), true); + runner + .handleAdException( + jobParameter, + lockService, + lock, + Instant.now().minusMillis(1000 * 60), + Instant.now(), + exception, + recorder, + detector + ); + verify(anomalyResultHandler).index(any(), any(), any()); + } + + @Test + public void testRunAdJobWithEndRunExceptionNowAndExistingAdJob() { + testRunAdJobWithEndRunExceptionNowAndStopAdJob(true, true, true); + verify(anomalyResultHandler).index(any(), any(), any()); + verify(client).index(any(IndexRequest.class), any()); + assertTrue(testAppender.containsMessage("EndRunException happened when executing anomaly result action for")); + assertTrue(testAppender.containsMessage("JobRunner will stop AD job due to EndRunException for")); + assertTrue(testAppender.containsMessage("AD Job was disabled by JobRunner for")); + } + + @Test + public void testRunAdJobWithEndRunExceptionNowAndExistingAdJobAndIndexException() { + testRunAdJobWithEndRunExceptionNowAndStopAdJob(true, true, false); + verify(anomalyResultHandler).index(any(), any(), any()); + verify(client).index(any(IndexRequest.class), any()); + assertTrue(testAppender.containsMessage("Failed to disable AD job for")); + } + + @Test + public void testRunAdJobWithEndRunExceptionNowAndNotExistingEnabledAdJob() { + testRunAdJobWithEndRunExceptionNowAndStopAdJob(false, true, true); + verify(client, never()).index(any(), any()); + assertFalse(testAppender.containsMessage("AD Job was disabled by JobRunner for")); + assertFalse(testAppender.containsMessage("Failed to disable AD job for")); + assertTrue(testAppender.containsMessage("AD Job was not found for")); + verify(anomalyResultHandler).index(any(), any(), any()); + verify(adTaskManager).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); + } + + @Test + public void testRunAdJobWithEndRunExceptionNowAndExistingDisabledAdJob() { + 
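+        // Descriptive note: the job exists but is already disabled, so the anomaly result is still indexed
+        // while no job update is written, and neither the "not found" nor the "disabled by JobRunner"
+        // message should be logged.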
testRunAdJobWithEndRunExceptionNowAndStopAdJob(true, false, true); + verify(anomalyResultHandler).index(any(), any(), any()); + verify(client, never()).index(any(), any()); + assertFalse(testAppender.containsMessage("AD Job was not found for")); + assertFalse(testAppender.containsMessage("AD Job was disabled by JobRunner for")); + } + + @Test + public void testRunAdJobWithEndRunExceptionNowAndNotExistingDisabledAdJob() { + testRunAdJobWithEndRunExceptionNowAndStopAdJob(false, false, true); + verify(anomalyResultHandler).index(any(), any(), any()); + verify(client, never()).index(any(), any()); + assertFalse(testAppender.containsMessage("AD Job was disabled by JobRunner for")); + } + + private void testRunAdJobWithEndRunExceptionNowAndStopAdJob(boolean jobExists, boolean jobEnabled, boolean disableSuccessfully) { + LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); + Exception exception = new EndRunException(jobParameter.getName(), randomAlphaOfLength(5), true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = new GetResponse( + new GetResult( + CommonName.JOB_INDEX, + jobParameter.getName(), + UNASSIGNED_SEQ_NO, + 0, + -1, + jobExists, + BytesReference + .bytes( + new AnomalyDetectorJob( + jobParameter.getName(), + jobParameter.getSchedule(), + jobParameter.getWindowDelay(), + jobEnabled, + Instant.now().minusSeconds(60), + Instant.now(), + Instant.now(), + 60L, + TestHelpers.randomUser(), + jobParameter.getCustomResultIndex() + ).toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS) + ), + Collections.emptyMap(), + Collections.emptyMap() + ) + ); + + listener.onResponse(response); + return null; + }).when(client).get(any(GetRequest.class), any()); + + doAnswer(invocation -> { + IndexRequest request = invocation.getArgument(0); + ActionListener listener = invocation.getArgument(1); + ShardId shardId = new ShardId(new Index(CommonName.JOB_INDEX, randomAlphaOfLength(10)), 0); + if (disableSuccessfully) { + listener.onResponse(new IndexResponse(shardId, request.id(), 1, 1, 1, true)); + } else { + listener.onResponse(null); + } + return null; + }).when(client).index(any(IndexRequest.class), any()); + + runner + .handleAdException( + jobParameter, + lockService, + lock, + Instant.now().minusMillis(1000 * 60), + Instant.now(), + exception, + recorder, + detector + ); + } + + @Test + public void testRunAdJobWithEndRunExceptionNowAndGetJobException() { + LockModel lock = new LockModel("indexName", "jobId", Instant.now(), 10, false); + Exception exception = new EndRunException(jobParameter.getName(), randomAlphaOfLength(5), true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("test")); + return null; + }).when(client).get(any(GetRequest.class), any()); + + runner + .handleAdException( + jobParameter, + lockService, + lock, + Instant.now().minusMillis(1000 * 60), + Instant.now(), + exception, + recorder, + detector + ); + assertTrue(testAppender.containsMessage("JobRunner will stop AD job due to EndRunException for")); + assertTrue(testAppender.containsMessage("JobRunner failed to get detector job")); + verify(anomalyResultHandler).index(any(), any(), any()); + assertEquals(1, testAppender.countMessage("JobRunner failed to get detector job")); + } + + @SuppressWarnings("unchecked") + @Test + public void testRunAdJobWithEndRunExceptionNowAndFailToGetJob() { + LockModel lock = new LockModel("indexName", 
"jobId", Instant.now(), 10, false); + Exception exception = new EndRunException(jobParameter.getName(), randomAlphaOfLength(5), true); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + GetRequest request = (GetRequest) args[0]; + ActionListener listener = (ActionListener) args[1]; + + if (request.index().equals(CommonName.JOB_INDEX)) { + listener.onFailure(new RuntimeException("fail to get AD job")); + } + return null; + }).when(client).get(any(), any()); + + runner + .handleAdException( + jobParameter, + lockService, + lock, + Instant.now().minusMillis(1000 * 60), + Instant.now(), + exception, + recorder, + detector + ); + verify(anomalyResultHandler).index(any(), any(), any()); + assertEquals(1, testAppender.countMessage("JobRunner failed to get detector job")); + } + + @Test + public void testRunAdJobWithEndRunExceptionNotNowAndRetryUntilStop() throws InterruptedException { + LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); + Instant executionStartTime = Instant.now(); + Schedule schedule = mock(IntervalSchedule.class); + when(jobParameter.getSchedule()).thenReturn(schedule); + when(schedule.getNextExecutionTime(executionStartTime)).thenReturn(executionStartTime.plusSeconds(5)); + + doAnswer(invocation -> { + Exception exception = new EndRunException(jobParameter.getName(), randomAlphaOfLength(5), false); + ActionListener listener = invocation.getArgument(2); + listener.onFailure(exception); + return null; + }).when(client).execute(any(), any(), any()); + + for (int i = 0; i < 3; i++) { + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + assertEquals(i + 1, testAppender.countMessage("EndRunException happened for")); + } + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + assertEquals(1, testAppender.countMessage("JobRunner will stop AD job due to EndRunException retry exceeds upper limit")); + } + + private void setUpJobParameter() { + when(jobParameter.getName()).thenReturn(randomAlphaOfLength(10)); + IntervalSchedule schedule = new IntervalSchedule(Instant.now(), 1, ChronoUnit.MINUTES); + when(jobParameter.getSchedule()).thenReturn(schedule); + when(jobParameter.getWindowDelay()).thenReturn(new IntervalTimeConfiguration(10, ChronoUnit.SECONDS)); + } + + /** + * Test updateLatestRealtimeTask.confirmTotalRCFUpdatesFound + * @throws InterruptedException + */ + public Instant confirmInitializedSetup() { + // clear the appender created in setUp before creating another association; otherwise + // we will have unexpected error (e.g., some appender does not record messages even + // though we have configured to do so). 
+ super.tearDownLog4jForJUnit(); + setUpLog4jForJUnit(ExecuteADResultResponseRecorder.class, true); + Schedule schedule = mock(IntervalSchedule.class); + when(jobParameter.getSchedule()).thenReturn(schedule); + Instant executionStartTime = Instant.now(); + when(schedule.getNextExecutionTime(executionStartTime)).thenReturn(executionStartTime.plusSeconds(5)); + + AnomalyResultResponse response = new AnomalyResultResponse( + 4d, + 0.993, + 1.01, + Collections.singletonList(new FeatureData("123", "abc", 0d)), + randomAlphaOfLength(4), + // not fully initialized + Long.valueOf(AnomalyDetectorSettings.NUM_MIN_SAMPLES - 1), + randomLong(), + // not an HC detector + false, + randomInt(), + new double[] { randomDoubleBetween(0, 1.0, true), randomDoubleBetween(0, 1.0, true) }, + new double[] { randomDouble(), randomDouble() }, + new double[][] { new double[] { randomDouble(), randomDouble() } }, + new double[] { randomDouble() }, + randomDoubleBetween(1.1, 10.0, true) + ); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(client).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); + return executionStartTime; + } + + @SuppressWarnings("unchecked") + public void testFailtoFindDetector() { + Instant executionStartTime = confirmInitializedSetup(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException()); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); + + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + + verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); + verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); + verify(nodeStateManager, times(1)).getAnomalyDetector(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(0)).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); + assertEquals(1, testAppender.countMessage("Fail to confirm rcf update")); + assertTrue(testAppender.containExceptionMsg(TimeSeriesException.class, "fail to get detector")); + } + + @SuppressWarnings("unchecked") + public void testFailtoFindJob() { + Instant executionStartTime = confirmInitializedSetup(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException()); + return null; + }).when(nodeStateManager).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + + LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); + + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + + verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); + verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); + verify(nodeStateManager, 
times(1)).getAnomalyDetector(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(1)).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); + assertEquals(1, testAppender.countMessage("Fail to confirm rcf update")); + assertTrue(testAppender.containExceptionMsg(TimeSeriesException.class, "fail to get job")); + } + + @SuppressWarnings("unchecked") + public void testEmptyDetector() { + Instant executionStartTime = confirmInitializedSetup(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); + + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + + verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); + verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); + verify(nodeStateManager, times(1)).getAnomalyDetector(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(0)).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); + assertEquals(1, testAppender.countMessage("Fail to confirm rcf update")); + assertTrue(testAppender.containExceptionMsg(TimeSeriesException.class, "fail to get detector")); + } + + @SuppressWarnings("unchecked") + public void testEmptyJob() { + Instant executionStartTime = confirmInitializedSetup(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(nodeStateManager).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + + LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); + + runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + + verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); + verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); + verify(nodeStateManager, times(1)).getAnomalyDetector(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(1)).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); + assertEquals(1, testAppender.countMessage("Fail to confirm rcf update")); + assertTrue(testAppender.containExceptionMsg(TimeSeriesException.class, "fail to get job")); + } + + @SuppressWarnings("unchecked") + public void testMarkResultIndexQueried() throws IOException { + detector = TestHelpers.AnomalyDetectorBuilder + .newInstance() + .setDetectionInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)) + .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) + 
.setResultIndex(ADCommonName.CUSTOM_RESULT_INDEX_PREFIX + "index") + .build(); + Instant executionStartTime = confirmInitializedSetup(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(TestHelpers.randomAnomalyDetectorJob(true, Instant.ofEpochMilli(1602401500000L), null))); + return null; + }).when(nodeStateManager).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + + ActionListener listener = (ActionListener) args[1]; + + SearchResponse mockResponse = mock(SearchResponse.class); + int totalHits = 1001; + when(mockResponse.getHits()).thenReturn(TestHelpers.createSearchHits(totalHits)); + + listener.onResponse(mockResponse); + + return null; + }).when(client).search(any(), any(ActionListener.class)); + + // use a unmocked adTaskCacheManager to test the value of hasQueriedResultIndex has changed + Settings settings = Settings + .builder() + .put(AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE.getKey(), 2) + .put(AnomalyDetectorSettings.MAX_CACHED_DELETED_TASKS.getKey(), 100) + .build(); + + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays.asList(AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE, AnomalyDetectorSettings.MAX_CACHED_DELETED_TASKS) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + MemoryTracker memoryTracker = mock(MemoryTracker.class); + adTaskCacheManager = new ADTaskCacheManager(settings, clusterService, memoryTracker); + + // init real time task cache for the detector. We will do this during AnomalyResultTransportAction. + // Since we mocked the execution by returning anomaly result directly, we need to init it explicitly. 
+        adTaskCacheManager.initRealtimeTaskCache(detector.getId(), 0);
+
+        // recreate recorder since we need to use the unmocked adTaskCacheManager
+        recorder = new ExecuteADResultResponseRecorder(
+            anomalyDetectionIndices,
+            anomalyResultHandler,
+            adTaskManager,
+            nodeFilter,
+            threadPool,
+            client,
+            nodeStateManager,
+            adTaskCacheManager,
+            32
+        );
+
+        assertEquals(false, adTaskCacheManager.hasQueriedResultIndex(detector.getId()));
+
+        LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false);
+
+        runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector);
+
+        verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any());
+        verify(client, times(1)).search(any(), any());
+        verify(nodeStateManager, times(1)).getAnomalyDetector(any(String.class), any(ActionListener.class));
+        verify(nodeStateManager, times(1)).getAnomalyDetectorJob(any(String.class), any(ActionListener.class));
+
+        ArgumentCaptor<Long> totalUpdates = ArgumentCaptor.forClass(Long.class);
+        verify(adTaskManager, times(1))
+            .updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), totalUpdates.capture(), any(), any(), any());
+        assertEquals(NUM_MIN_SAMPLES, totalUpdates.getValue().longValue());
+        assertEquals(true, adTaskCacheManager.hasQueriedResultIndex(detector.getId()));
+    }
+}
diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java
index f786be6b8..5d3c54541 100644
--- a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java
+++ b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java
@@ -56,9 +56,9 @@
 import org.opensearch.ad.util.SecurityClientUtil;
 import org.opensearch.cluster.ClusterName;
 import org.opensearch.cluster.node.DiscoveryNode;
-import org.opensearch.common.io.stream.NotSerializableExceptionWrapper;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.transport.TransportAddress;
+import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper;
 import org.opensearch.index.IndexNotFoundException;
 import org.opensearch.timeseries.TestHelpers;
 import org.opensearch.timeseries.common.exception.ResourceNotFoundException;
diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java-e b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java-e
new file mode 100644
index 000000000..737c6c684
--- /dev/null
+++ b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java-e
@@ -0,0 +1,653 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.ad;
+
+import static java.util.Collections.emptyMap;
+import static java.util.Collections.emptySet;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.*;
+
+import java.io.IOException;
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+
+import org.opensearch.Version;
+import org.opensearch.action.ActionListener;
+import org.opensearch.action.FailedNodeException;
+import org.opensearch.action.get.GetRequest;
+import org.opensearch.action.get.GetResponse;
+import org.opensearch.ad.constant.ADCommonName;
+import org.opensearch.ad.model.ADTask;
+import org.opensearch.ad.model.AnomalyDetector;
+import org.opensearch.ad.model.AnomalyDetectorJob;
+import org.opensearch.ad.model.DetectorInternalState;
+import org.opensearch.ad.model.DetectorProfile;
+import org.opensearch.ad.model.DetectorProfileName;
+import org.opensearch.ad.model.DetectorState;
+import org.opensearch.ad.model.InitProgressProfile;
+import org.opensearch.ad.model.ModelProfileOnNode;
+import org.opensearch.ad.transport.ProfileAction;
+import org.opensearch.ad.transport.ProfileNodeResponse;
+import org.opensearch.ad.transport.ProfileResponse;
+import org.opensearch.ad.transport.RCFPollingAction;
+import org.opensearch.ad.transport.RCFPollingResponse;
+import org.opensearch.ad.util.SecurityClientUtil;
+import org.opensearch.cluster.ClusterName;
+import org.opensearch.cluster.node.DiscoveryNode;
+import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.common.transport.TransportAddress;
+import org.opensearch.index.IndexNotFoundException;
+import org.opensearch.timeseries.TestHelpers;
+import org.opensearch.timeseries.common.exception.ResourceNotFoundException;
+import org.opensearch.timeseries.common.exception.TimeSeriesException;
+import org.opensearch.timeseries.constant.CommonMessages;
+import org.opensearch.timeseries.constant.CommonName;
+import org.opensearch.timeseries.model.IntervalTimeConfiguration;
+import org.opensearch.transport.RemoteTransportException;
+
+public class AnomalyDetectorProfileRunnerTests extends AbstractProfileRunnerTests {
+    enum RCFPollingStatus {
+        INIT_NOT_EXIT,
+        REMOTE_INIT_NOT_EXIT,
+        INDEX_NOT_FOUND,
+        REMOTE_INDEX_NOT_FOUND,
+        INIT_DONE,
+        EMPTY,
+        EXCEPTION,
+        INITTING
+    }
+
+    private Instant jobEnabledTime = Instant.now().minus(1, ChronoUnit.DAYS);
+
+    /**
+     * Convenience method for setting up single-stream detector profile tests
+     * @param detectorStatus Detector config status
+     * @param jobStatus Detector job status
+     * @param rcfPollingStatus RCF polling result status
+     * @param errorResultStatus Error result status
+     * @throws IOException when the get request fails
+     */
+    @SuppressWarnings("unchecked")
+    private void setUpClientGet(
+        DetectorStatus detectorStatus,
+        JobStatus jobStatus,
+        RCFPollingStatus rcfPollingStatus,
+        ErrorResultStatus errorResultStatus
+    ) throws IOException {
+        detector = TestHelpers.randomAnomalyDetectorWithInterval(new IntervalTimeConfiguration(detectorIntervalMin, ChronoUnit.MINUTES));
+        NodeStateManager nodeStateManager = mock(NodeStateManager.class);
+        doAnswer(invocation -> {
+            ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1);
+            listener.onResponse(Optional.of(detector));
+            return null;
+        }).when(nodeStateManager).getAnomalyDetector(anyString(), any(ActionListener.class));
+        clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY);
+        runner = new AnomalyDetectorProfileRunner(
+            client,
+            clientUtil,
+            xContentRegistry(),
+            nodeFilter,
+            requiredSamples,
+            transportService,
+            adTaskManager
+        );
+
+        doAnswer(invocation -> {
+            Object[] args = invocation.getArguments();
+            GetRequest request = (GetRequest) args[0];
+            ActionListener<GetResponse> listener = (ActionListener<GetResponse>) args[1];
+
+            if (request.index().equals(CommonName.CONFIG_INDEX)) {
+                switch (detectorStatus) {
+                    case EXIST:
+                        listener.onResponse(TestHelpers.createGetResponse(detector, detector.getId(), CommonName.CONFIG_INDEX));
+                        break;
+                    case INDEX_NOT_EXIST:
+                        listener.onFailure(new IndexNotFoundException(CommonName.CONFIG_INDEX));
+                        break;
+                    case NO_DOC:
+                        when(detectorGetReponse.isExists()).thenReturn(false);
+                        listener.onResponse(detectorGetReponse);
+                        break;
+                    default:
+                        assertTrue("should not reach here", false);
+                        break;
+                }
+            } else if (request.index().equals(CommonName.JOB_INDEX)) {
+                AnomalyDetectorJob job = null;
+                switch (jobStatus) {
+                    case INDEX_NOT_EXIT:
+                        listener.onFailure(new IndexNotFoundException(CommonName.JOB_INDEX));
+                        break;
+                    case DISABLED:
+                        job = TestHelpers.randomAnomalyDetectorJob(false, jobEnabledTime, null);
+                        listener.onResponse(TestHelpers.createGetResponse(job, detector.getId(), CommonName.JOB_INDEX));
+                        break;
+                    case ENABLED:
+                        job = TestHelpers.randomAnomalyDetectorJob(true, jobEnabledTime, null);
+                        listener.onResponse(TestHelpers.createGetResponse(job, detector.getId(), CommonName.JOB_INDEX));
+                        break;
+                    default:
+                        assertTrue("should not reach here", false);
+                        break;
+                }
+            } else {
+                if (errorResultStatus == ErrorResultStatus.INDEX_NOT_EXIT) {
+                    listener.onFailure(new IndexNotFoundException(ADCommonName.DETECTION_STATE_INDEX));
+                    return null;
+                }
+                DetectorInternalState.Builder result = new DetectorInternalState.Builder().lastUpdateTime(Instant.now());
+
+                String error = getError(errorResultStatus);
+                if (error != null) {
+                    result.error(error);
+                }
+                listener.onResponse(TestHelpers.createGetResponse(result.build(), detector.getId(), ADCommonName.DETECTION_STATE_INDEX));
+
+            }
+
+            return null;
+        }).when(client).get(any(), any());
+
+        setUpClientExecuteRCFPollingAction(rcfPollingStatus);
+    }
+
+    private String getError(ErrorResultStatus errorResultStatus) {
+        switch (errorResultStatus) {
+            case NO_ERROR:
+                break;
+            case SHINGLE_ERROR:
+                return noFullShingleError;
+            case STOPPED_ERROR:
+                return stoppedError;
+            default:
+                assertTrue("should not reach here", false);
+                break;
+        }
+        return null;
+    }
+
+    public void testDetectorNotExist() throws IOException, InterruptedException {
+        setUpClientGet(DetectorStatus.INDEX_NOT_EXIST, JobStatus.INDEX_NOT_EXIT, RCFPollingStatus.EMPTY, ErrorResultStatus.NO_ERROR);
+        final CountDownLatch inProgressLatch = new CountDownLatch(1);
+
+        runner.profile("x123", ActionListener.wrap(response -> {
+            assertTrue("Should not reach here", false);
+            inProgressLatch.countDown();
+        }, exception -> {
+            assertTrue(exception.getMessage().contains(CommonMessages.FAIL_TO_FIND_CONFIG_MSG));
+            inProgressLatch.countDown();
+        }), stateNError);
+        assertTrue(inProgressLatch.await(100,
TimeUnit.SECONDS)); + } + + public void testDisabledJobIndexTemplate(JobStatus status) throws IOException, InterruptedException { + setUpClientGet(DetectorStatus.EXIST, status, RCFPollingStatus.EMPTY, ErrorResultStatus.NO_ERROR); + DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.DISABLED).build(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detector.getId(), ActionListener.wrap(response -> { + assertEquals(expectedProfile, response); + inProgressLatch.countDown(); + }, exception -> { + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }), stateOnly); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testNoJobIndex() throws IOException, InterruptedException { + testDisabledJobIndexTemplate(JobStatus.INDEX_NOT_EXIT); + } + + public void testJobDisabled() throws IOException, InterruptedException { + testDisabledJobIndexTemplate(JobStatus.DISABLED); + } + + public void testInitOrRunningStateTemplate(RCFPollingStatus status, DetectorState expectedState) throws IOException, + InterruptedException { + setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, status, ErrorResultStatus.NO_ERROR); + DetectorProfile expectedProfile = new DetectorProfile.Builder().state(expectedState).build(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detector.getId(), ActionListener.wrap(response -> { + assertEquals(expectedProfile, response); + inProgressLatch.countDown(); + }, exception -> { + logger.error(exception); + for (StackTraceElement ste : exception.getStackTrace()) { + logger.info(ste); + } + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }), stateOnly); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testResultNotExist() throws IOException, InterruptedException { + testInitOrRunningStateTemplate(RCFPollingStatus.INIT_NOT_EXIT, DetectorState.INIT); + } + + public void testRemoteResultNotExist() throws IOException, InterruptedException { + testInitOrRunningStateTemplate(RCFPollingStatus.REMOTE_INIT_NOT_EXIT, DetectorState.INIT); + } + + public void testCheckpointIndexNotExist() throws IOException, InterruptedException { + testInitOrRunningStateTemplate(RCFPollingStatus.INDEX_NOT_FOUND, DetectorState.INIT); + } + + public void testRemoteCheckpointIndexNotExist() throws IOException, InterruptedException { + testInitOrRunningStateTemplate(RCFPollingStatus.REMOTE_INDEX_NOT_FOUND, DetectorState.INIT); + } + + public void testResultEmpty() throws IOException, InterruptedException { + testInitOrRunningStateTemplate(RCFPollingStatus.EMPTY, DetectorState.INIT); + } + + public void testResultGreaterThanZero() throws IOException, InterruptedException { + testInitOrRunningStateTemplate(RCFPollingStatus.INIT_DONE, DetectorState.RUNNING); + } + + @SuppressWarnings("unchecked") + public void testErrorStateTemplate( + RCFPollingStatus initStatus, + ErrorResultStatus status, + DetectorState state, + String error, + JobStatus jobStatus, + Set profilesToCollect + ) throws IOException, + InterruptedException { + ADTask adTask = TestHelpers.randomAdTask(); + + adTask.setError(getError(status)); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + Consumer> function = (Consumer>) args[2]; + function.accept(Optional.of(adTask)); + return null; + }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(any(), any(), any(), any(), anyBoolean(), any()); + + 
setUpClientExecuteRCFPollingAction(initStatus); + setUpClientGet(DetectorStatus.EXIST, jobStatus, initStatus, status); + DetectorProfile.Builder builder = new DetectorProfile.Builder(); + if (profilesToCollect.contains(DetectorProfileName.STATE)) { + builder.state(state); + } + if (profilesToCollect.contains(DetectorProfileName.ERROR)) { + builder.error(error); + } + DetectorProfile expectedProfile = builder.build(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detector.getId(), ActionListener.wrap(response -> { + assertEquals(expectedProfile, response); + inProgressLatch.countDown(); + }, exception -> { + logger.info(exception); + for (StackTraceElement ste : exception.getStackTrace()) { + logger.info(ste); + } + assertTrue("Should not reach here", false); + inProgressLatch.countDown(); + }), profilesToCollect); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testErrorStateTemplate( + RCFPollingStatus initStatus, + ErrorResultStatus status, + DetectorState state, + String error, + JobStatus jobStatus + ) throws IOException, + InterruptedException { + testErrorStateTemplate(initStatus, status, state, error, jobStatus, stateNError); + } + + public void testRunningNoError() throws IOException, InterruptedException { + testErrorStateTemplate(RCFPollingStatus.INIT_DONE, ErrorResultStatus.NO_ERROR, DetectorState.RUNNING, null, JobStatus.ENABLED); + } + + public void testRunningWithError() throws IOException, InterruptedException { + testErrorStateTemplate( + RCFPollingStatus.INIT_DONE, + ErrorResultStatus.SHINGLE_ERROR, + DetectorState.RUNNING, + noFullShingleError, + JobStatus.ENABLED + ); + } + + public void testDisabledForStateError() throws IOException, InterruptedException { + testErrorStateTemplate( + RCFPollingStatus.INITTING, + ErrorResultStatus.STOPPED_ERROR, + DetectorState.DISABLED, + stoppedError, + JobStatus.DISABLED + ); + } + + public void testDisabledForStateInit() throws IOException, InterruptedException { + testErrorStateTemplate( + RCFPollingStatus.INITTING, + ErrorResultStatus.STOPPED_ERROR, + DetectorState.DISABLED, + stoppedError, + JobStatus.DISABLED, + stateInitProgress + ); + } + + public void testInitWithError() throws IOException, InterruptedException { + testErrorStateTemplate( + RCFPollingStatus.EMPTY, + ErrorResultStatus.SHINGLE_ERROR, + DetectorState.INIT, + noFullShingleError, + JobStatus.ENABLED + ); + } + + @SuppressWarnings("unchecked") + private void setUpClientExecuteProfileAction() { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[2]; + + node1 = "node1"; + nodeName1 = "nodename1"; + discoveryNode1 = new DiscoveryNode( + nodeName1, + node1, + new TransportAddress(TransportAddress.META_ADDRESS, 9300), + emptyMap(), + emptySet(), + Version.CURRENT + ); + + node2 = "node2"; + nodeName2 = "nodename2"; + discoveryNode2 = new DiscoveryNode( + nodeName2, + node2, + new TransportAddress(TransportAddress.META_ADDRESS, 9301), + emptyMap(), + emptySet(), + Version.CURRENT + ); + + modelSize = 4456448L; + model1Id = "Pl536HEBnXkDrah03glg_model_rcf_1"; + model0Id = "Pl536HEBnXkDrah03glg_model_rcf_0"; + + shingleSize = 6; + + String clusterName = "test-cluster-name"; + + Map modelSizeMap1 = new HashMap() { + { + put(model1Id, modelSize); + } + }; + + Map modelSizeMap2 = new HashMap() { + { + put(model0Id, modelSize); + } + }; + + ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse( + discoveryNode1, + 
modelSizeMap1, + shingleSize, + 0L, + 0L, + new ArrayList<>(), + modelSizeMap1.size() + ); + ProfileNodeResponse profileNodeResponse2 = new ProfileNodeResponse( + discoveryNode2, + modelSizeMap2, + -1, + 0L, + 0L, + new ArrayList<>(), + modelSizeMap2.size() + ); + List profileNodeResponses = Arrays.asList(profileNodeResponse1, profileNodeResponse2); + List failures = Collections.emptyList(); + ProfileResponse profileResponse = new ProfileResponse(new ClusterName(clusterName), profileNodeResponses, failures); + + listener.onResponse(profileResponse); + + return null; + }).when(client).execute(any(ProfileAction.class), any(), any()); + + } + + @SuppressWarnings("unchecked") + private void setUpClientExecuteRCFPollingAction(RCFPollingStatus inittedEverResultStatus) { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[2]; + + Exception cause = null; + String detectorId = "123"; + if (inittedEverResultStatus == RCFPollingStatus.INIT_NOT_EXIT + || inittedEverResultStatus == RCFPollingStatus.REMOTE_INIT_NOT_EXIT + || inittedEverResultStatus == RCFPollingStatus.INDEX_NOT_FOUND + || inittedEverResultStatus == RCFPollingStatus.REMOTE_INDEX_NOT_FOUND) { + switch (inittedEverResultStatus) { + case INIT_NOT_EXIT: + case REMOTE_INIT_NOT_EXIT: + cause = new ResourceNotFoundException(detectorId, messaingExceptionError); + break; + case INDEX_NOT_FOUND: + case REMOTE_INDEX_NOT_FOUND: + cause = new IndexNotFoundException(detectorId, ADCommonName.CHECKPOINT_INDEX_NAME); + break; + default: + assertTrue("should not reach here", false); + break; + } + cause = new TimeSeriesException(detectorId, cause); + if (inittedEverResultStatus == RCFPollingStatus.REMOTE_INIT_NOT_EXIT + || inittedEverResultStatus == RCFPollingStatus.REMOTE_INDEX_NOT_FOUND) { + cause = new RemoteTransportException(RCFPollingAction.NAME, new NotSerializableExceptionWrapper(cause)); + } + listener.onFailure(cause); + } else { + RCFPollingResponse result = null; + switch (inittedEverResultStatus) { + case INIT_DONE: + result = new RCFPollingResponse(requiredSamples + 1); + break; + case INITTING: + result = new RCFPollingResponse(requiredSamples - neededSamples); + break; + case EMPTY: + result = new RCFPollingResponse(0); + break; + case EXCEPTION: + listener.onFailure(new RuntimeException()); + break; + default: + assertTrue("should not reach here", false); + break; + } + + listener.onResponse(result); + } + return null; + }).when(client).execute(any(RCFPollingAction.class), any(), any()); + + } + + public void testProfileModels() throws InterruptedException, IOException { + setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, RCFPollingStatus.EMPTY, ErrorResultStatus.NO_ERROR); + setUpClientExecuteProfileAction(); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detector.getId(), ActionListener.wrap(profileResponse -> { + assertEquals(node1, profileResponse.getCoordinatingNode()); + assertEquals(shingleSize, profileResponse.getShingleSize()); + assertEquals(modelSize * 2, profileResponse.getTotalSizeInBytes()); + assertEquals(2, profileResponse.getModelProfile().length); + for (ModelProfileOnNode profile : profileResponse.getModelProfile()) { + assertTrue(node1.equals(profile.getNodeId()) || node2.equals(profile.getNodeId())); + assertEquals(modelSize, profile.getModelSize()); + if (node1.equals(profile.getNodeId())) { + assertEquals(model1Id, profile.getModelId()); + } + if (node2.equals(profile.getNodeId())) { + 
assertEquals(model0Id, profile.getModelId()); + } + } + inProgressLatch.countDown(); + }, exception -> { + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }), modelProfile); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testInitProgress() throws IOException, InterruptedException { + setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, RCFPollingStatus.INITTING, ErrorResultStatus.NO_ERROR); + DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.INIT).build(); + + // 123 / 128 rounded to 96% + InitProgressProfile profile = new InitProgressProfile("96%", neededSamples * detectorIntervalMin, neededSamples); + expectedProfile.setInitProgress(profile); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detector.getId(), ActionListener.wrap(response -> { + assertEquals(expectedProfile, response); + inProgressLatch.countDown(); + }, exception -> { + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }), stateInitProgress); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testInitProgressFailImmediately() throws IOException, InterruptedException { + setUpClientGet(DetectorStatus.NO_DOC, JobStatus.ENABLED, RCFPollingStatus.INITTING, ErrorResultStatus.NO_ERROR); + DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.INIT).build(); + + // 123 / 128 rounded to 96% + InitProgressProfile profile = new InitProgressProfile("96%", neededSamples * detectorIntervalMin, neededSamples); + expectedProfile.setInitProgress(profile); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detector.getId(), ActionListener.wrap(response -> { + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }, exception -> { + assertTrue(exception.getMessage().contains(CommonMessages.FAIL_TO_FIND_CONFIG_MSG)); + inProgressLatch.countDown(); + }), stateInitProgress); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testInitNoUpdateNoIndex() throws IOException, InterruptedException { + setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, RCFPollingStatus.EMPTY, ErrorResultStatus.NO_ERROR); + DetectorProfile expectedProfile = new DetectorProfile.Builder() + .state(DetectorState.INIT) + .initProgress(new InitProgressProfile("0%", detectorIntervalMin * requiredSamples, requiredSamples)) + .build(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detector.getId(), ActionListener.wrap(response -> { + assertEquals(expectedProfile, response); + inProgressLatch.countDown(); + }, exception -> { + LOG.error(exception); + for (StackTraceElement ste : exception.getStackTrace()) { + LOG.info(ste); + } + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }), stateInitProgress); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testInitNoIndex() throws IOException, InterruptedException { + setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, RCFPollingStatus.INDEX_NOT_FOUND, ErrorResultStatus.NO_ERROR); + DetectorProfile expectedProfile = new DetectorProfile.Builder() + .state(DetectorState.INIT) + .initProgress(new InitProgressProfile("0%", 0, requiredSamples)) + .build(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detector.getId(), ActionListener.wrap(response -> { + assertEquals(expectedProfile, response); + 
inProgressLatch.countDown(); + }, exception -> { + LOG.error(exception); + for (StackTraceElement ste : exception.getStackTrace()) { + LOG.info(ste); + } + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }), stateInitProgress); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testInvalidRequiredSamples() { + expectThrows( + IllegalArgumentException.class, + () -> new AnomalyDetectorProfileRunner(client, clientUtil, xContentRegistry(), nodeFilter, 0, transportService, adTaskManager) + ); + } + + public void testFailRCFPolling() throws IOException, InterruptedException { + setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, RCFPollingStatus.EXCEPTION, ErrorResultStatus.NO_ERROR); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detector.getId(), ActionListener.wrap(response -> { + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }, exception -> { + assertTrue(exception instanceof RuntimeException); + // this means we don't exit with failImmediately. failImmediately can make we return early when there are other concurrent + // requests + assertTrue(exception.getMessage(), exception.getMessage().contains("Exceptions:")); + inProgressLatch.countDown(); + }), stateNError); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testInitProgressProfile() { + InitProgressProfile progressOne = new InitProgressProfile("0%", 0, requiredSamples); + InitProgressProfile progressTwo = new InitProgressProfile("0%", 0, requiredSamples); + InitProgressProfile progressThree = new InitProgressProfile("96%", 2, requiredSamples); + assertTrue(progressOne.equals(progressTwo)); + assertFalse(progressOne.equals(progressThree)); + } +} diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java b/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java index c5827e0c5..6ff4d604d 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java @@ -34,14 +34,14 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentParserUtils; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.rest.RestStatus; +import org.opensearch.core.xcontent.XContentParserUtils; import org.opensearch.test.rest.OpenSearchRestTestCase; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.model.DateRange; diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java-e b/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java-e new file mode 100644 index 000000000..a38915577 --- /dev/null +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java-e @@ -0,0 +1,665 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. 
See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; + +import java.io.IOException; +import java.io.InputStream; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Locale; +import java.util.Map; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorExecutionInput; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.common.Strings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentParserUtils; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.test.rest.OpenSearchRestTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.util.RestHandlerUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.gson.JsonArray; + +public abstract class AnomalyDetectorRestTestCase extends ODFERestTestCase { + + public static final int MAX_RETRY_TIMES = 10; + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(ImmutableList.of(AnomalyDetector.XCONTENT_REGISTRY)); + } + + @Override + protected Settings restClientSettings() { + return super.restClientSettings(); + } + + protected AnomalyDetector createRandomAnomalyDetector(Boolean refresh, Boolean withMetadata, String indexName, RestClient client) + throws IOException { + return createRandomAnomalyDetector(refresh, withMetadata, client, true, indexName); + } + + protected AnomalyDetector createRandomAnomalyDetector(Boolean refresh, Boolean withMetadata, RestClient client) throws IOException { + return createRandomAnomalyDetector(refresh, withMetadata, client, true); + } + + protected AnomalyDetector createRandomAnomalyDetector(Boolean refresh, Boolean withMetadata, RestClient client, boolean featureEnabled) + throws IOException { + return createRandomAnomalyDetector(refresh, withMetadata, client, featureEnabled, null); + } + + protected AnomalyDetector createRandomAnomalyDetector( + Boolean refresh, + Boolean withMetadata, + RestClient client, + boolean featureEnabled, + String indexName + ) throws IOException { + Map uiMetadata = null; + if (withMetadata) { + uiMetadata = TestHelpers.randomUiMetadata(); + } + + AnomalyDetector detector = null; + + if (indexName == null) { + detector = TestHelpers.randomAnomalyDetector(uiMetadata, null, featureEnabled); + TestHelpers.createIndexWithTimeField(client(), detector.getIndices().get(0), detector.getTimeField()); + TestHelpers + .makeRequest( + client, + "POST", + "/" + detector.getIndices().get(0) + "/_doc/" + randomAlphaOfLength(5) + "?refresh=true", + 
ImmutableMap.of(), + // avoid validation error as validation API will check at least 1 document and the timestamp field + // exists in index mapping + TestHelpers.toHttpEntity("{\"name\": \"test\", \"" + detector.getTimeField() + "\" : \"1661386754000\"}"), + null, + false + ); + } else { + detector = TestHelpers + .randomAnomalyDetector( + ImmutableList.of(indexName), + ImmutableList.of(TestHelpers.randomFeature(featureEnabled)), + uiMetadata, + Instant.now(), + OpenSearchRestTestCase.randomLongBetween(1, 1000), + true, + null + ); + } + + AnomalyDetector createdDetector = createAnomalyDetector(detector, refresh, client); + + if (withMetadata) { + return getAnomalyDetector(createdDetector.getId(), new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"), client); + } + return getAnomalyDetector(createdDetector.getId(), new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json"), client); + } + + protected AnomalyDetector createAnomalyDetector(AnomalyDetector detector, Boolean refresh, RestClient client) throws IOException { + Response response = TestHelpers + .makeRequest(client, "POST", TestHelpers.AD_BASE_DETECTORS_URI, ImmutableMap.of(), TestHelpers.toHttpEntity(detector), null); + assertEquals("Create anomaly detector failed", RestStatus.CREATED, TestHelpers.restStatus(response)); + + Map detectorJson = jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, response.getEntity().getContent()) + .map(); + String detectorId = (String) detectorJson.get("_id"); + AnomalyDetector detectorInIndex = null; + int i = 0; + do { + i++; + try { + detectorInIndex = getAnomalyDetector(detectorId, client); + assertNotNull(detectorInIndex); + break; + } catch (Exception e) { + try { + Thread.sleep(2000); + } catch (InterruptedException ex) { + logger.error("Failed to sleep after creating detector", ex); + } + } + } while (i < MAX_RETRY_TIMES); + assertNotNull("Can't get anomaly detector from index", detectorInIndex); + // Adding additional sleep time in order to have more time between AD Creation and whichever + // step comes next in terms of accessing/update/deleting the detector, this will help avoid + // lots of flaky tests + try { + Thread.sleep(5000); + } catch (InterruptedException ex) { + logger.error("Failed to sleep after creating detector", ex); + } + return detectorInIndex; + } + + protected Response startAnomalyDetector(String detectorId, DateRange dateRange, RestClient client) throws IOException { + return TestHelpers + .makeRequest( + client, + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detectorId + "/_start", + ImmutableMap.of(), + dateRange == null ? null : TestHelpers.toHttpEntity(dateRange), + null + ); + } + + protected Response stopAnomalyDetector(String detectorId, RestClient client, boolean realtime) throws IOException { + String jobType = realtime ? 
"" : "?historical"; + return TestHelpers + .makeRequest( + client, + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detectorId + "/_stop" + jobType, + ImmutableMap.of(), + "", + null + ); + } + + protected Response deleteAnomalyDetector(String detectorId, RestClient client) throws IOException { + return TestHelpers.makeRequest(client, "DELETE", TestHelpers.AD_BASE_DETECTORS_URI + "/" + detectorId, ImmutableMap.of(), "", null); + } + + protected Response previewAnomalyDetector(String detectorId, RestClient client, AnomalyDetectorExecutionInput input) + throws IOException { + return TestHelpers + .makeRequest( + client, + "POST", + String.format(Locale.ROOT, TestHelpers.AD_BASE_PREVIEW_URI, input.getDetectorId()), + ImmutableMap.of(), + TestHelpers.toHttpEntity(input), + null + ); + } + + public AnomalyDetector getAnomalyDetector(String detectorId, RestClient client) throws IOException { + return (AnomalyDetector) getAnomalyDetector(detectorId, false, client)[0]; + } + + public Response updateAnomalyDetector(String detectorId, AnomalyDetector newDetector, RestClient client) throws IOException { + BasicHeader header = new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json"); + return TestHelpers + .makeRequest( + client, + "PUT", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detectorId, + null, + TestHelpers.toJsonString(newDetector), + ImmutableList.of(header) + ); + } + + public AnomalyDetector getAnomalyDetector(String detectorId, BasicHeader header, RestClient client) throws IOException { + return (AnomalyDetector) getAnomalyDetector(detectorId, header, false, false, client)[0]; + } + + public ToXContentObject[] getAnomalyDetector(String detectorId, boolean returnJob, RestClient client) throws IOException { + BasicHeader header = new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json"); + return getAnomalyDetector(detectorId, header, returnJob, false, client); + } + + public ToXContentObject[] getAnomalyDetector( + String detectorId, + BasicHeader header, + boolean returnJob, + boolean returnTask, + RestClient client + ) throws IOException { + Response response = TestHelpers + .makeRequest( + client, + "GET", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detectorId + "?job=" + returnJob + "&task=" + returnTask, + null, + "", + ImmutableList.of(header) + ); + assertEquals("Unable to get anomaly detector " + detectorId, RestStatus.OK, TestHelpers.restStatus(response)); + XContentParser parser = createAdParser(XContentType.JSON.xContent(), response.getEntity().getContent()); + parser.nextToken(); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + + String id = null; + Long version = null; + AnomalyDetector detector = null; + AnomalyDetectorJob detectorJob = null; + ADTask realtimeAdTask = null; + ADTask historicalAdTask = null; + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case "_id": + id = parser.text(); + break; + case "_version": + version = parser.longValue(); + break; + case "anomaly_detector": + detector = AnomalyDetector.parse(parser); + break; + case "anomaly_detector_job": + detectorJob = AnomalyDetectorJob.parse(parser); + break; + case "realtime_detection_task": + if (parser.currentToken() != XContentParser.Token.VALUE_NULL) { + realtimeAdTask = ADTask.parse(parser); + } + break; + case "historical_analysis_task": + if (parser.currentToken() != XContentParser.Token.VALUE_NULL) { + historicalAdTask = 
ADTask.parse(parser); + } + break; + default: + parser.skipChildren(); + break; + } + } + + return new ToXContentObject[] { + new AnomalyDetector( + id, + version, + detector.getName(), + detector.getDescription(), + detector.getTimeField(), + detector.getIndices(), + detector.getFeatureAttributes(), + detector.getFilterQuery(), + detector.getInterval(), + detector.getWindowDelay(), + detector.getShingleSize(), + detector.getUiMetadata(), + detector.getSchemaVersion(), + detector.getLastUpdateTime(), + null, + detector.getUser(), + detector.getCustomResultIndex(), + detector.getImputationOption() + ), + detectorJob, + historicalAdTask, + realtimeAdTask }; + } + + protected final XContentParser createAdParser(XContent xContent, InputStream data) throws IOException { + return xContent.createParser(TestHelpers.xContentRegistry(), LoggingDeprecationHandler.INSTANCE, data); + } + + public void updateClusterSettings(String settingKey, Object value) throws Exception { + XContentBuilder builder = XContentFactory + .jsonBuilder() + .startObject() + .startObject("persistent") + .field(settingKey, value) + .endObject() + .endObject(); + Request request = new Request("PUT", "_cluster/settings"); + request.setJsonEntity(Strings.toString(builder)); + Response response = client().performRequest(request); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + Thread.sleep(2000); // sleep some time to resolve flaky test + } + + public Response getDetectorProfile(String detectorId, boolean all, String customizedProfile, RestClient client) throws IOException { + return TestHelpers + .makeRequest( + client, + "GET", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detectorId + "/" + RestHandlerUtils.PROFILE + customizedProfile + "?_all=" + all, + null, + "", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response getDetectorProfile(String detectorId) throws IOException { + return getDetectorProfile(detectorId, false, "", client()); + } + + public Response getDetectorProfile(String detectorId, boolean all) throws IOException { + return getDetectorProfile(detectorId, all, "", client()); + } + + public Response getSearchDetectorCount() throws IOException { + return TestHelpers + .makeRequest( + client(), + "GET", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + RestHandlerUtils.COUNT, + null, + "", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response getSearchDetectorMatch(String name) throws IOException { + return TestHelpers + .makeRequest( + client(), + "GET", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + RestHandlerUtils.MATCH, + ImmutableMap.of("name", name), + "", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response searchTopAnomalyResults(String detectorId, boolean historical, String bodyAsJsonString, RestClient client) + throws IOException { + return TestHelpers + .makeRequest( + client, + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + + "/" + + detectorId + + "/" + + RestHandlerUtils.RESULTS + + "/" + + RestHandlerUtils.TOP_ANOMALIES, + Collections.singletonMap("historical", String.valueOf(historical)), + TestHelpers.toHttpEntity(bodyAsJsonString), + new ArrayList<>() + ); + } + + public Response createUser(String name, String password, ArrayList backendRoles) throws IOException { + JsonArray backendRolesString = new JsonArray(); + for (int i = 0; i < backendRoles.size(); i++) { + backendRolesString.add(backendRoles.get(i)); + } + return 
TestHelpers + .makeRequest( + client(), + "PUT", + "/_opendistro/_security/api/internalusers/" + name, + null, + TestHelpers + .toHttpEntity( + " {\n" + + "\"password\": \"" + + password + + "\",\n" + + "\"backend_roles\": " + + backendRolesString + + ",\n" + + "\"attributes\": {\n" + + "}} " + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response createRoleMapping(String role, ArrayList users) throws IOException { + JsonArray usersString = new JsonArray(); + for (int i = 0; i < users.size(); i++) { + usersString.add(users.get(i)); + } + return TestHelpers + .makeRequest( + client(), + "PUT", + "/_opendistro/_security/api/rolesmapping/" + role, + null, + TestHelpers + .toHttpEntity( + "{\n" + " \"backend_roles\" : [ ],\n" + " \"hosts\" : [ ],\n" + " \"users\" : " + usersString + "\n" + "}" + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response createIndexRole(String role, String index) throws IOException { + return TestHelpers + .makeRequest( + client(), + "PUT", + "/_opendistro/_security/api/roles/" + role, + null, + TestHelpers + .toHttpEntity( + "{\n" + + "\"cluster_permissions\": [\n" + + "],\n" + + "\"index_permissions\": [\n" + + "{\n" + + "\"index_patterns\": [\n" + + "\"" + + index + + "\"\n" + + "],\n" + + "\"dls\": \"\",\n" + + "\"fls\": [],\n" + + "\"masked_fields\": [],\n" + + "\"allowed_actions\": [\n" + + "\"crud\",\n" + + "\"indices:admin/create\"\n" + + "]\n" + + "}\n" + + "],\n" + + "\"tenant_permissions\": []\n" + + "}" + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response createSearchRole(String role, String index) throws IOException { + return TestHelpers + .makeRequest( + client(), + "PUT", + "/_opendistro/_security/api/roles/" + role, + null, + TestHelpers + .toHttpEntity( + "{\n" + + "\"cluster_permissions\": [\n" + + "],\n" + + "\"index_permissions\": [\n" + + "{\n" + + "\"index_patterns\": [\n" + + "\"" + + index + + "\"\n" + + "],\n" + + "\"dls\": \"\",\n" + + "\"fls\": [],\n" + + "\"masked_fields\": [],\n" + + "\"allowed_actions\": [\n" + + "\"indices:data/read/search\"\n" + + "]\n" + + "}\n" + + "],\n" + + "\"tenant_permissions\": []\n" + + "}" + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response createDlsRole(String role, String index) throws IOException { + return TestHelpers + .makeRequest( + client(), + "PUT", + "/_opendistro/_security/api/roles/" + role, + null, + TestHelpers + .toHttpEntity( + "{\n" + + "\"cluster_permissions\": [\n" + + "unlimited\n" + + "],\n" + + "\"index_permissions\": [\n" + + "{\n" + + "\"index_patterns\": [\n" + + "\"" + + index + + "\"\n" + + "],\n" + + "\"dls\": \"\"\"{ \"bool\": { \"must\": { \"match\": { \"foo\": \"bar\" }}}}\"\"\",\n" + + "\"fls\": [],\n" + + "\"masked_fields\": [],\n" + + "\"allowed_actions\": [\n" + + "\"unlimited\"\n" + + "]\n" + + "},\n" + + "{\n" + + "\"index_patterns\": [\n" + + "\"" + + "*" + + "\"\n" + + "],\n" + + "\"dls\": \"\",\n" + + "\"fls\": [],\n" + + "\"masked_fields\": [],\n" + + "\"allowed_actions\": [\n" + + "\"unlimited\"\n" + + "]\n" + + "}\n" + + "],\n" + + "\"tenant_permissions\": []\n" + + "}" + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response deleteUser(String user) throws IOException { + return TestHelpers + .makeRequest( + client(), + "DELETE", + "/_opendistro/_security/api/internalusers/" + user, + null, + "", + ImmutableList.of(new 
BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response deleteRoleMapping(String user) throws IOException { + return TestHelpers + .makeRequest( + client(), + "DELETE", + "/_opendistro/_security/api/rolesmapping/" + user, + null, + "", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response enableFilterBy() throws IOException { + return TestHelpers + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + TestHelpers + .toHttpEntity( + "{\n" + + " \"persistent\": {\n" + + " \"opendistro.anomaly_detection.filter_by_backend_roles\" : \"true\"\n" + + " }\n" + + "}" + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + public Response disableFilterBy() throws IOException { + return TestHelpers + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + TestHelpers + .toHttpEntity( + "{\n" + + " \"persistent\": {\n" + + " \"opendistro.anomaly_detection.filter_by_backend_roles\" : \"false\"\n" + + " }\n" + + "}" + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + } + + protected AnomalyDetector cloneDetector(AnomalyDetector anomalyDetector, String resultIndex) { + AnomalyDetector detector = new AnomalyDetector( + null, + null, + randomAlphaOfLength(5), + randomAlphaOfLength(10), + anomalyDetector.getTimeField(), + anomalyDetector.getIndices(), + anomalyDetector.getFeatureAttributes(), + anomalyDetector.getFilterQuery(), + anomalyDetector.getInterval(), + anomalyDetector.getWindowDelay(), + anomalyDetector.getShingleSize(), + anomalyDetector.getUiMetadata(), + anomalyDetector.getSchemaVersion(), + Instant.now(), + anomalyDetector.getCategoryFields(), + null, + resultIndex, + anomalyDetector.getImputationOption() + ); + return detector; + } + + protected Response validateAnomalyDetector(AnomalyDetector detector, RestClient client) throws IOException { + return TestHelpers + .makeRequest( + client, + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/_validate", + ImmutableMap.of(), + TestHelpers.toHttpEntity(detector), + null + ); + } + +} diff --git a/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java b/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java index 2c16271de..1004b01a4 100644 --- a/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java @@ -47,8 +47,8 @@ import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.index.IndexNotFoundException; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchHit; diff --git a/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java-e b/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java-e new file mode 100644 index 000000000..2788f170b --- /dev/null +++ b/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java-e @@ -0,0 +1,433 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad; + +import static java.util.Collections.emptyMap; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import java.io.IOException; +import java.time.temporal.ChronoUnit; +import java.util.*; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.apache.lucene.search.TotalHits; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.model.EntityProfile; +import org.opensearch.ad.model.EntityProfileName; +import org.opensearch.ad.model.EntityState; +import org.opensearch.ad.model.InitProgressProfile; +import org.opensearch.ad.model.ModelProfile; +import org.opensearch.ad.model.ModelProfileOnNode; +import org.opensearch.ad.transport.EntityProfileAction; +import org.opensearch.ad.transport.EntityProfileResponse; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.aggregations.metrics.InternalMax; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; + +public class EntityProfileRunnerTests extends AbstractTimeSeriesTest { + private AnomalyDetector detector; + private int detectorIntervalMin; + private Client client; + private SecurityClientUtil clientUtil; + private EntityProfileRunner runner; + private Set state; + private Set initNInfo; + private Set model; + private String detectorId; + private String entityValue; + private int requiredSamples; + private AnomalyDetectorJob job; + + private int smallUpdates; + private String categoryField; + private long latestSampleTimestamp; + private long latestActiveTimestamp; + private Boolean isActive; + private String modelId; + private long modelSize; + private String nodeId; + private Entity entity; + + enum InittedEverResultStatus { + UNKNOWN, + INITTED, + NOT_INITTED, + } + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyDetectorJobRunnerTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @SuppressWarnings("unchecked") + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + detectorIntervalMin = 3; + + state = new HashSet(); + 
state.add(EntityProfileName.STATE); + + initNInfo = new HashSet(); + initNInfo.add(EntityProfileName.INIT_PROGRESS); + initNInfo.add(EntityProfileName.ENTITY_INFO); + + model = new HashSet(); + model.add(EntityProfileName.MODELS); + + detectorId = "A69pa3UBHuCbh-emo9oR"; + entityValue = "app-0"; + + categoryField = "a"; + detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(categoryField)); + job = TestHelpers.randomAnomalyDetectorJob(true); + + requiredSamples = 128; + client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + + runner = new EntityProfileRunner(client, clientUtil, xContentRegistry(), requiredSamples); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + GetRequest request = (GetRequest) args[0]; + ActionListener listener = (ActionListener) args[1]; + + String indexName = request.index(); + if (indexName.equals(CommonName.CONFIG_INDEX)) { + listener.onResponse(TestHelpers.createGetResponse(detector, detector.getId(), CommonName.CONFIG_INDEX)); + } else if (indexName.equals(CommonName.JOB_INDEX)) { + listener.onResponse(TestHelpers.createGetResponse(job, detector.getId(), CommonName.JOB_INDEX)); + } + + return null; + }).when(client).get(any(), any()); + + entity = Entity.createSingleAttributeEntity(categoryField, entityValue); + modelId = entity.getModelId(detectorId).get(); + } + + @SuppressWarnings("unchecked") + private void setUpSearch() { + latestSampleTimestamp = 1_603_989_830_158L; + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + SearchRequest request = (SearchRequest) args[0]; + String indexName = request.indices()[0]; + ActionListener listener = (ActionListener) args[1]; + if (indexName.equals(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)) { + InternalMax maxAgg = new InternalMax(CommonName.AGG_NAME_MAX_TIME, latestSampleTimestamp, DocValueFormat.RAW, emptyMap()); + InternalAggregations internalAggregations = InternalAggregations.from(Collections.singletonList(maxAgg)); + + SearchHits hits = new SearchHits(new SearchHit[] {}, null, Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections(hits, internalAggregations, null, false, false, null, 1); + + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 30, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + + listener.onResponse(searchResponse); + } + return null; + }).when(client).search(any(), any()); + } + + @SuppressWarnings("unchecked") + private void setUpExecuteEntityProfileAction(InittedEverResultStatus initted) { + smallUpdates = 1; + latestActiveTimestamp = 1603999189758L; + isActive = Boolean.TRUE; + modelId = "T4c3dXUBj-2IZN7itix__entity_" + entityValue; + modelSize = 712480L; + nodeId = "g6pmr547QR-CfpEvO67M4g"; + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[2]; + + EntityProfileResponse.Builder profileResponseBuilder = new EntityProfileResponse.Builder(); + if (InittedEverResultStatus.UNKNOWN == initted) { + profileResponseBuilder.setTotalUpdates(0L); + } else if 
(InittedEverResultStatus.NOT_INITTED == initted) { + profileResponseBuilder.setTotalUpdates(smallUpdates); + profileResponseBuilder.setLastActiveMs(latestActiveTimestamp); + profileResponseBuilder.setActive(isActive); + } else { + profileResponseBuilder.setTotalUpdates(requiredSamples + 1); + ModelProfileOnNode model = new ModelProfileOnNode(nodeId, new ModelProfile(modelId, entity, modelSize)); + profileResponseBuilder.setModelProfile(model); + } + + listener.onResponse(profileResponseBuilder.build()); + + return null; + }).when(client).execute(any(EntityProfileAction.class), any(), any()); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + SearchRequest request = (SearchRequest) args[0]; + String indexName = request.indices()[0]; + ActionListener listener = (ActionListener) args[1]; + SearchResponse searchResponse = null; + if (indexName.equals(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)) { + InternalMax maxAgg = new InternalMax(CommonName.AGG_NAME_MAX_TIME, latestSampleTimestamp, DocValueFormat.RAW, emptyMap()); + InternalAggregations internalAggregations = InternalAggregations.from(Collections.singletonList(maxAgg)); + + SearchHits hits = new SearchHits(new SearchHit[] {}, null, Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections(hits, internalAggregations, null, false, false, null, 1); + + searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 30, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + } else { + SearchHits collapsedHits = new SearchHits( + new SearchHit[] { + new SearchHit(2, "ID", Collections.emptyMap(), Collections.emptyMap()), + new SearchHit(3, "ID", Collections.emptyMap(), Collections.emptyMap()) }, + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 1.0F + ); + + InternalSearchResponse internalSearchResponse = new InternalSearchResponse(collapsedHits, null, null, null, false, null, 1); + searchResponse = new SearchResponse( + internalSearchResponse, + null, + 1, + 1, + 0, + 0, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + } + + listener.onResponse(searchResponse); + + return null; + + }).when(client).search(any(), any()); + } + + public void stateTestTemplate(InittedEverResultStatus returnedState, EntityState expectedState) throws InterruptedException { + setUpExecuteEntityProfileAction(returnedState); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detectorId, entity, state, ActionListener.wrap(response -> { + assertEquals(expectedState, response.getState()); + inProgressLatch.countDown(); + }, exception -> { + assertTrue("Should not reach here", false); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testRunningState() throws InterruptedException { + stateTestTemplate(InittedEverResultStatus.INITTED, EntityState.RUNNING); + } + + public void testUnknownState() throws InterruptedException { + stateTestTemplate(InittedEverResultStatus.UNKNOWN, EntityState.UNKNOWN); + } + + public void testInitState() throws InterruptedException { + stateTestTemplate(InittedEverResultStatus.NOT_INITTED, EntityState.INIT); + } + + public void testEmptyProfile() throws InterruptedException { + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detectorId, entity, new HashSet<>(), ActionListener.wrap(response -> { + assertTrue("Should not reach here", false); + inProgressLatch.countDown(); + }, exception -> { + 
assertTrue(exception.getMessage().contains(ADCommonMessages.EMPTY_PROFILES_COLLECT)); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testModel() throws InterruptedException { + setUpExecuteEntityProfileAction(InittedEverResultStatus.INITTED); + EntityProfile.Builder expectedProfile = new EntityProfile.Builder(); + + ModelProfileOnNode modelProfile = new ModelProfileOnNode(nodeId, new ModelProfile(modelId, entity, modelSize)); + expectedProfile.modelProfile(modelProfile); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + runner.profile(detectorId, entity, model, ActionListener.wrap(response -> { + assertEquals(expectedProfile.build(), response); + inProgressLatch.countDown(); + }, exception -> { + assertTrue("Should not reach here", false); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testEmptyModelProfile() throws IOException { + ModelProfile modelProfile = new ModelProfile(modelId, null, modelSize); + BytesStreamOutput output = new BytesStreamOutput(); + modelProfile.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + ModelProfile readResponse = new ModelProfile(streamInput); + assertEquals("serialization has the wrong model id", modelId, readResponse.getModelId()); + assertTrue("serialization has null entity", null == readResponse.getEntity()); + assertEquals("serialization has the wrong model size", modelSize, readResponse.getModelSizeInBytes()); + + } + + @SuppressWarnings("unchecked") + public void testJobIndexNotFound() throws InterruptedException { + setUpExecuteEntityProfileAction(InittedEverResultStatus.INITTED); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + GetRequest request = (GetRequest) args[0]; + ActionListener listener = (ActionListener) args[1]; + + String indexName = request.index(); + if (indexName.equals(CommonName.CONFIG_INDEX)) { + listener.onResponse(TestHelpers.createGetResponse(detector, detector.getId(), CommonName.CONFIG_INDEX)); + } else if (indexName.equals(CommonName.JOB_INDEX)) { + listener.onFailure(new IndexNotFoundException(CommonName.JOB_INDEX)); + } + + return null; + }).when(client).get(any(), any()); + + EntityProfile expectedProfile = new EntityProfile.Builder().build(); + + runner.profile(detectorId, entity, initNInfo, ActionListener.wrap(response -> { + assertEquals(expectedProfile, response); + inProgressLatch.countDown(); + }, exception -> { + LOG.error("Unexpected error", exception); + assertTrue("Should not reach here", false); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + @SuppressWarnings("unchecked") + public void testNotMultiEntityDetector() throws IOException, InterruptedException { + detector = TestHelpers.randomAnomalyDetectorWithInterval(new IntervalTimeConfiguration(detectorIntervalMin, ChronoUnit.MINUTES)); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + GetRequest request = (GetRequest) args[0]; + ActionListener listener = (ActionListener) args[1]; + + String indexName = request.index(); + if (indexName.equals(CommonName.CONFIG_INDEX)) { + listener.onResponse(TestHelpers.createGetResponse(detector, detector.getId(), CommonName.CONFIG_INDEX)); + } + + return null; + }).when(client).get(any(), any()); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + 
runner.profile(detectorId, entity, state, ActionListener.wrap(response -> { + assertTrue("Should not reach here", false); + inProgressLatch.countDown(); + }, exception -> { + assertTrue(exception.getMessage().contains(EntityProfileRunner.NOT_HC_DETECTOR_ERR_MSG)); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testInitNInfo() throws InterruptedException { + setUpExecuteEntityProfileAction(InittedEverResultStatus.NOT_INITTED); + latestSampleTimestamp = 1_603_989_830_158L; + + EntityProfile.Builder expectedProfile = new EntityProfile.Builder(); + + // 1 / 128 rounded to 1% + int neededSamples = requiredSamples - smallUpdates; + InitProgressProfile profile = new InitProgressProfile("1%", neededSamples * detector.getIntervalInSeconds() / 60, neededSamples); + expectedProfile.initProgress(profile); + expectedProfile.isActive(isActive); + expectedProfile.lastActiveTimestampMs(latestActiveTimestamp); + expectedProfile.lastSampleTimestampMs(latestSampleTimestamp); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detectorId, entity, initNInfo, ActionListener.wrap(response -> { + assertEquals(expectedProfile.build(), response); + inProgressLatch.countDown(); + }, exception -> { + LOG.error("Unexpected error", exception); + assertTrue("Should not reach here", false); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } +} diff --git a/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java b/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java index 119337846..b19eb7242 100644 --- a/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java +++ b/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java @@ -44,10 +44,10 @@ import org.opensearch.ad.transport.AnomalyDetectorJobAction; import org.opensearch.ad.transport.AnomalyDetectorJobRequest; import org.opensearch.ad.transport.AnomalyDetectorJobResponse; +import org.opensearch.core.rest.RestStatus; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.plugins.Plugin; -import org.opensearch.rest.RestStatus; import org.opensearch.search.SearchHit; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.builder.SearchSourceBuilder; diff --git a/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java-e b/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java-e new file mode 100644 index 000000000..d76d708b1 --- /dev/null +++ b/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java-e @@ -0,0 +1,257 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad; + +import static org.opensearch.ad.model.ADTask.DETECTOR_ID_FIELD; +import static org.opensearch.ad.model.ADTask.EXECUTION_START_TIME_FIELD; +import static org.opensearch.ad.model.ADTask.IS_LATEST_FIELD; +import static org.opensearch.ad.model.ADTask.PARENT_TASK_ID_FIELD; +import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM; +import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; +import static org.opensearch.timeseries.util.RestHandlerUtils.START_JOB; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.mock.plugin.MockReindexPlugin; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskState; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.transport.AnomalyDetectorJobAction; +import org.opensearch.ad.transport.AnomalyDetectorJobRequest; +import org.opensearch.ad.transport.AnomalyDetectorJobResponse; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.plugins.Plugin; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.search.SearchHit; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.test.transport.MockTransportService; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Feature; + +import com.google.common.collect.ImmutableList; + +public abstract class HistoricalAnalysisIntegTestCase extends ADIntegTestCase { + + protected String testIndex = "test_historical_data"; + protected int detectionIntervalInMinutes = 1; + protected int DEFAULT_TEST_DATA_DOCS = 3000; + protected String DEFAULT_IP = "127.0.0.1"; + + @Override + protected Collection> getMockPlugins() { + final ArrayList> plugins = new ArrayList<>(); + plugins.add(MockReindexPlugin.class); + plugins.addAll(super.getMockPlugins()); + plugins.remove(MockTransportService.TestPlugin.class); + return Collections.unmodifiableList(plugins); + } + + public void ingestTestData(String testIndex, Instant startTime, int detectionIntervalInMinutes, String type) { + ingestTestData(testIndex, startTime, detectionIntervalInMinutes, type, DEFAULT_IP, DEFAULT_TEST_DATA_DOCS, true); + } + + public void ingestTestData(String testIndex, Instant startTime, int detectionIntervalInMinutes, String type, int totalDocs) { + ingestTestData(testIndex, startTime, detectionIntervalInMinutes, type, DEFAULT_IP, totalDocs, true); + } + + public void ingestTestData( + String testIndex, + Instant startTime, + int detectionIntervalInMinutes, + String type, + String ip, + int totalDocs, + boolean createIndexFirst + ) { + if (createIndexFirst) { + 
createTestDataIndex(testIndex); + } + List> docs = new ArrayList<>(); + Instant currentInterval = Instant.from(startTime); + + for (int i = 0; i < totalDocs; i++) { + currentInterval = currentInterval.plus(detectionIntervalInMinutes, ChronoUnit.MINUTES); + double value = i % 500 == 0 ? randomDoubleBetween(1000, 2000, true) : randomDoubleBetween(10, 100, true); + Map doc = new HashMap<>(); + doc.put(timeField, currentInterval.toEpochMilli()); + doc.put("value", value); + doc.put("ip", ip); + doc.put("type", type); + doc.put("is_error", randomBoolean()); + doc.put("message", randomAlphaOfLength(5)); + docs.add(doc); + } + BulkResponse bulkResponse = bulkIndexDocs(testIndex, docs, 30_000); + assertEquals(RestStatus.OK, bulkResponse.status()); + assertFalse(bulkResponse.hasFailures()); + long count = countDocs(testIndex); + if (createIndexFirst) { + assertEquals(totalDocs, count); + } + } + + public Feature maxValueFeature() throws IOException { + AggregationBuilder aggregationBuilder = TestHelpers.parseAggregation("{\"test\":{\"max\":{\"field\":\"" + valueField + "\"}}}"); + return new Feature(randomAlphaOfLength(5), randomAlphaOfLength(10), true, aggregationBuilder); + } + + public AnomalyDetector randomDetector(List features) throws IOException { + return TestHelpers.randomDetector(features, testIndex, detectionIntervalInMinutes, timeField); + } + + public ADTask randomCreatedADTask(String taskId, AnomalyDetector detector, DateRange detectionDateRange) { + String detectorId = detector == null ? null : detector.getId(); + return randomCreatedADTask(taskId, detector, detectorId, detectionDateRange); + } + + public ADTask randomCreatedADTask(String taskId, AnomalyDetector detector, String detectorId, DateRange detectionDateRange) { + return randomADTask(taskId, detector, detectorId, detectionDateRange, ADTaskState.CREATED); + } + + public ADTask randomADTask( + String taskId, + AnomalyDetector detector, + String detectorId, + DateRange detectionDateRange, + ADTaskState state + ) { + ADTask.Builder builder = ADTask + .builder() + .taskId(taskId) + .taskType(ADTaskType.HISTORICAL_SINGLE_ENTITY.name()) + .detectorId(detectorId) + .detectionDateRange(detectionDateRange) + .detector(detector) + .state(state.name()) + .taskProgress(0.0f) + .initProgress(0.0f) + .isLatest(true) + .startedBy(randomAlphaOfLength(5)) + .executionStartTime(Instant.now().minus(randomLongBetween(10, 100), ChronoUnit.MINUTES)); + if (ADTaskState.FINISHED == state) { + setPropertyForNotRunningTask(builder); + } else if (ADTaskState.FAILED == state) { + setPropertyForNotRunningTask(builder); + builder.error(randomAlphaOfLength(5)); + } else if (ADTaskState.STOPPED == state) { + setPropertyForNotRunningTask(builder); + builder.error(randomAlphaOfLength(5)); + builder.stoppedBy(randomAlphaOfLength(5)); + } + return builder.build(); + } + + private ADTask.Builder setPropertyForNotRunningTask(ADTask.Builder builder) { + builder.executionEndTime(Instant.now().minus(randomLongBetween(1, 5), ChronoUnit.MINUTES)); + builder.isLatest(false); + return builder; + } + + public List searchADTasks(String detectorId, Boolean isLatest, int size) throws IOException { + return searchADTasks(detectorId, null, isLatest, size); + } + + public List searchADTasks(String detectorId, String parentTaskId, Boolean isLatest, int size) throws IOException { + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, detectorId)); + if (isLatest != null) { + query.filter(new TermQueryBuilder(IS_LATEST_FIELD, 
isLatest)); + } + if (parentTaskId != null) { + query.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, parentTaskId)); + } + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(query).sort(EXECUTION_START_TIME_FIELD, SortOrder.DESC).trackTotalHits(true).size(size); + searchRequest.source(sourceBuilder).indices(ADCommonName.DETECTION_STATE_INDEX); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + Iterator iterator = searchResponse.getHits().iterator(); + + List adTasks = new ArrayList<>(); + while (iterator.hasNext()) { + SearchHit next = iterator.next(); + ADTask task = ADTask.parse(TestHelpers.parser(next.getSourceAsString()), next.getId()); + adTasks.add(task); + } + return adTasks; + } + + public ADTask getADTask(String taskId) throws IOException { + ADTask adTask = toADTask(getDoc(ADCommonName.DETECTION_STATE_INDEX, taskId)); + adTask.setTaskId(taskId); + return adTask; + } + + public AnomalyDetectorJob getADJob(String detectorId) throws IOException { + return toADJob(getDoc(CommonName.JOB_INDEX, detectorId)); + } + + public ADTask toADTask(GetResponse doc) throws IOException { + return ADTask.parse(TestHelpers.parser(doc.getSourceAsString())); + } + + public AnomalyDetectorJob toADJob(GetResponse doc) throws IOException { + return AnomalyDetectorJob.parse(TestHelpers.parser(doc.getSourceAsString())); + } + + public ADTask startHistoricalAnalysis(Instant startTime, Instant endTime) throws IOException { + DateRange dateRange = new DateRange(startTime, endTime); + AnomalyDetector detector = TestHelpers + .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField); + String detectorId = createDetector(detector); + AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( + detectorId, + dateRange, + true, + UNASSIGNED_SEQ_NO, + UNASSIGNED_PRIMARY_TERM, + START_JOB + ); + AnomalyDetectorJobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); + return getADTask(response.getId()); + } + + public ADTask startHistoricalAnalysis(String detectorId, Instant startTime, Instant endTime) throws IOException { + DateRange dateRange = new DateRange(startTime, endTime); + AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( + detectorId, + dateRange, + true, + UNASSIGNED_SEQ_NO, + UNASSIGNED_PRIMARY_TERM, + START_JOB + ); + AnomalyDetectorJobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); + return getADTask(response.getId()); + } +} diff --git a/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java b/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java index 464f6868c..17ecaa216 100644 --- a/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java +++ b/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java @@ -32,9 +32,9 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.client.Response; import org.opensearch.client.RestClient; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.rest.RestStatus; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.model.DateRange; diff --git a/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java-e 
b/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java-e new file mode 100644 index 000000000..7e60242ca --- /dev/null +++ b/src/test/java/org/opensearch/ad/HistoricalAnalysisRestTestCase.java-e @@ -0,0 +1,271 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.ToDoubleFunction; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.ParseException; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.ad.mock.model.MockSimpleLog; +import org.opensearch.ad.model.ADTaskProfile; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Feature; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public abstract class HistoricalAnalysisRestTestCase extends AnomalyDetectorRestTestCase { + + public static final int MAX_RETRY_TIMES = 200; + protected String historicalAnalysisTestIndex = "test_historical_analysis_data"; + protected int detectionIntervalInMinutes = 1; + protected int categoryFieldDocCount = 2; + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 1); + // ingest test data + ingestTestDataForHistoricalAnalysis(historicalAnalysisTestIndex, detectionIntervalInMinutes); + } + + public ToXContentObject[] getHistoricalAnomalyDetector(String detectorId, boolean returnTask, RestClient client) throws IOException { + BasicHeader header = new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json"); + return getAnomalyDetector(detectorId, header, false, returnTask, client); + } + + public ADTaskProfile getADTaskProfile(String detectorId) throws IOException, ParseException { + Response profileResponse = TestHelpers + .makeRequest( + client(), + "GET", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detectorId + "/_profile/ad_task", + ImmutableMap.of(), + "", + null + ); + return parseADTaskProfile(profileResponse); + } + + public Response searchTaskResult(String resultIndex, String taskId) throws IOException { + Response response = TestHelpers + .makeRequest( + client(), + "GET", + TestHelpers.AD_BASE_RESULT_URI + "/_search/" + resultIndex, + ImmutableMap.of(), + TestHelpers + .toHttpEntity( + "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"task_id\":\"" + taskId + "\"}}]}},\"track_total_hits\":true}" + ), + null + ); + return response; + } + + public Response ingestSimpleMockLog( + String 
indexName, + int startDays, + int totalDoc, + long intervalInMinutes, + ToDoubleFunction valueFunc, + int ipSize, + int categorySize + ) throws IOException, + ParseException { + TestHelpers + .makeRequest( + client(), + "PUT", + indexName, + null, + TestHelpers.toHttpEntity(MockSimpleLog.INDEX_MAPPING), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + + Response statsResponse = TestHelpers.makeRequest(client(), "GET", indexName, ImmutableMap.of(), "", null); + assertEquals(RestStatus.OK, TestHelpers.restStatus(statsResponse)); + String result = EntityUtils.toString(statsResponse.getEntity()); + assertTrue(result.contains(indexName)); + + StringBuilder bulkRequestBuilder = new StringBuilder(); + Instant startTime = Instant.now().minus(startDays, ChronoUnit.DAYS); + for (int i = 0; i < totalDoc; i++) { + for (int m = 0; m < ipSize; m++) { + String ip = "192.168.1." + m; + for (int n = 0; n < categorySize; n++) { + String category = "category" + n; + String docId = randomAlphaOfLength(10); + bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + indexName + "\", \"_id\" : \"" + docId + "\" } }\n"); + MockSimpleLog simpleLog1 = new MockSimpleLog( + startTime, + valueFunc.applyAsDouble(i), + ip, + category, + randomBoolean(), + randomAlphaOfLength(5) + ); + bulkRequestBuilder.append(TestHelpers.toJsonString(simpleLog1)); + bulkRequestBuilder.append("\n"); + } + } + startTime = startTime.plus(intervalInMinutes, ChronoUnit.MINUTES); + } + Response bulkResponse = TestHelpers + .makeRequest( + client(), + "POST", + "_bulk?refresh=true", + null, + TestHelpers.toHttpEntity(bulkRequestBuilder.toString()), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + return bulkResponse; + } + + public ADTaskProfile parseADTaskProfile(Response profileResponse) throws IOException, ParseException { + String profileResult = EntityUtils.toString(profileResponse.getEntity()); + XContentParser parser = TestHelpers.parser(profileResult); + ADTaskProfile adTaskProfile = null; + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + if ("ad_task".equals(fieldName)) { + adTaskProfile = ADTaskProfile.parse(parser); + } else { + parser.skipChildren(); + } + } + return adTaskProfile; + } + + protected void ingestTestDataForHistoricalAnalysis(String indexName, int detectionIntervalInMinutes) throws IOException, + ParseException { + ingestSimpleMockLog(indexName, 10, 3000, detectionIntervalInMinutes, (i) -> { + if (i % 500 == 0) { + return randomDoubleBetween(100, 1000, true); + } else { + return randomDoubleBetween(1, 10, true); + } + }, categoryFieldDocCount, categoryFieldDocCount); + } + + protected AnomalyDetector createAnomalyDetector() throws IOException, IllegalAccessException { + return createAnomalyDetector(0); + } + + protected AnomalyDetector createAnomalyDetector(int categoryFieldSize) throws IOException, IllegalAccessException { + return createAnomalyDetector(categoryFieldSize, null); + } + + protected AnomalyDetector createAnomalyDetector(int categoryFieldSize, String resultIndex) throws IOException, IllegalAccessException { + AggregationBuilder aggregationBuilder = TestHelpers + .parseAggregation("{\"test\":{\"max\":{\"field\":\"" + MockSimpleLog.VALUE_FIELD + "\"}}}"); + Feature feature = new Feature(randomAlphaOfLength(5), randomAlphaOfLength(10), true, aggregationBuilder); + List categoryField = null; + switch (categoryFieldSize) { + case 0: + break; + case 1: + categoryField 
= ImmutableList.of(MockSimpleLog.CATEGORY_FIELD); + break; + case 2: + categoryField = ImmutableList.of(MockSimpleLog.IP_FIELD, MockSimpleLog.CATEGORY_FIELD); + break; + default: + throw new IllegalAccessException("Wrong category field size"); + } + AnomalyDetector detector = TestHelpers + .randomDetector( + ImmutableList.of(feature), + historicalAnalysisTestIndex, + detectionIntervalInMinutes, + MockSimpleLog.TIME_FIELD, + categoryField, + resultIndex + ); + return createAnomalyDetector(detector, true, client()); + } + + protected String startHistoricalAnalysis(String detectorId) throws IOException { + Instant endTime = Instant.now().truncatedTo(ChronoUnit.SECONDS); + Instant startTime = endTime.minus(10, ChronoUnit.DAYS).truncatedTo(ChronoUnit.SECONDS); + DateRange dateRange = new DateRange(startTime, endTime); + Response startDetectorResponse = startAnomalyDetector(detectorId, dateRange, client()); + Map startDetectorResponseMap = responseAsMap(startDetectorResponse); + String taskId = (String) startDetectorResponseMap.get("_id"); + assertNotNull(taskId); + return taskId; + } + + protected ADTaskProfile waitUntilGetTaskProfile(String detectorId) throws InterruptedException { + int i = 0; + ADTaskProfile adTaskProfile = null; + while (adTaskProfile == null && i < 200) { + try { + adTaskProfile = getADTaskProfile(detectorId); + } catch (Exception e) {} finally { + Thread.sleep(100); + } + i++; + } + assertNotNull(adTaskProfile); + return adTaskProfile; + } + + // TODO: change response to pair + protected List waitUntilTaskDone(String detectorId) throws InterruptedException { + return waitUntilTaskReachState(detectorId, TestHelpers.HISTORICAL_ANALYSIS_DONE_STATS); + } + + protected List waitUntilTaskReachState(String detectorId, Set targetStates) throws InterruptedException { + List results = new ArrayList<>(); + int i = 0; + ADTaskProfile adTaskProfile = null; + // Increase retryTimes if some task can't reach done state + while ((adTaskProfile == null || !targetStates.contains(adTaskProfile.getAdTask().getState())) && i < MAX_RETRY_TIMES) { + try { + adTaskProfile = getADTaskProfile(detectorId); + } catch (Exception e) { + logger.error("failed to get ADTaskProfile", e); + } finally { + Thread.sleep(1000); + } + i++; + } + assertNotNull(adTaskProfile); + results.add(adTaskProfile); + results.add(i); + return results; + } +} diff --git a/src/test/java/org/opensearch/ad/MemoryTrackerTests.java-e b/src/test/java/org/opensearch/ad/MemoryTrackerTests.java-e new file mode 100644 index 000000000..f21b74b11 --- /dev/null +++ b/src/test/java/org/opensearch/ad/MemoryTrackerTests.java-e @@ -0,0 +1,328 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; + +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.ByteSizeValue; +import org.opensearch.monitor.jvm.JvmInfo; +import org.opensearch.monitor.jvm.JvmInfo.Mem; +import org.opensearch.monitor.jvm.JvmService; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.amazon.randomcutforest.config.Precision; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class MemoryTrackerTests extends OpenSearchTestCase { + + int inputFeatures; + int rcfSampleSize; + int numberOfTrees; + double rcfTimeDecay; + int numMinSamples; + int shingleSize; + int dimension; + MemoryTracker tracker; + long expectedRCFModelSize; + String detectorId; + long largeHeapSize; + long smallHeapSize; + Mem mem; + ThresholdedRandomCutForest trcf; + float modelMaxPercen; + ClusterService clusterService; + double modelMaxSizePercentage; + double modelDesiredSizePercentage; + JvmService jvmService; + AnomalyDetector detector; + ADCircuitBreakerService circuitBreaker; + + @Override + public void setUp() throws Exception { + super.setUp(); + inputFeatures = 1; + rcfSampleSize = 256; + numberOfTrees = 30; + rcfTimeDecay = 0.2; + numMinSamples = 128; + shingleSize = 8; + dimension = inputFeatures * shingleSize; + + jvmService = mock(JvmService.class); + JvmInfo info = mock(JvmInfo.class); + mem = mock(Mem.class); + // 800 MB is the limit + largeHeapSize = 800_000_000; + smallHeapSize = 1_000_000; + + when(jvmService.info()).thenReturn(info); + when(info.getMem()).thenReturn(mem); + + modelMaxSizePercentage = 0.1; + modelDesiredSizePercentage = 0.0002; + + clusterService = mock(ClusterService.class); + modelMaxPercen = 0.1f; + Settings settings = Settings.builder().put(AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.getKey(), modelMaxPercen).build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + expectedRCFModelSize = 382784; + detectorId = "123"; + + trcf = ThresholdedRandomCutForest + .builder() + .dimensions(dimension) + .sampleSize(rcfSampleSize) + .numberOfTrees(numberOfTrees) + .timeDecay(rcfTimeDecay) + .outputAfter(numMinSamples) + .initialAcceptFraction(numMinSamples * 1.0d / rcfSampleSize) + .parallelExecutionEnabled(false) + .compact(true) + .precision(Precision.FLOAT_32) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + .shingleSize(shingleSize) + .internalShinglingEnabled(true) + .build(); + + detector = mock(AnomalyDetector.class); + when(detector.getEnabledFeatureIds()).thenReturn(Collections.singletonList("a")); + when(detector.getShingleSize()).thenReturn(1); + + circuitBreaker = mock(ADCircuitBreakerService.class); + when(circuitBreaker.isOpen()).thenReturn(false); + } + + private void 
setUpBigHeap() { + ByteSizeValue value = new ByteSizeValue(largeHeapSize); + when(mem.getHeapMax()).thenReturn(value); + tracker = new MemoryTracker(jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, circuitBreaker); + } + + private void setUpSmallHeap() { + ByteSizeValue value = new ByteSizeValue(smallHeapSize); + when(mem.getHeapMax()).thenReturn(value); + tracker = new MemoryTracker(jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, circuitBreaker); + } + + public void testEstimateModelSize() { + setUpBigHeap(); + + assertEquals(403491, tracker.estimateTRCFModelSize(trcf)); + assertTrue(tracker.isHostingAllowed(detectorId, trcf)); + + ThresholdedRandomCutForest rcf2 = ThresholdedRandomCutForest + .builder() + .dimensions(32) // 32 to trigger another calculation of point store usage + .sampleSize(rcfSampleSize) + .numberOfTrees(numberOfTrees) + .timeDecay(rcfTimeDecay) + .outputAfter(numMinSamples) + .initialAcceptFraction(numMinSamples * 1.0d / rcfSampleSize) + .parallelExecutionEnabled(false) + .compact(true) + .precision(Precision.FLOAT_32) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + .internalShinglingEnabled(true) + // same with dimension for opportunistic memory saving + .shingleSize(shingleSize) + .build(); + assertEquals(603708, tracker.estimateTRCFModelSize(rcf2)); + assertTrue(tracker.isHostingAllowed(detectorId, rcf2)); + + ThresholdedRandomCutForest rcf3 = ThresholdedRandomCutForest + .builder() + .dimensions(9) + .sampleSize(rcfSampleSize) + .numberOfTrees(numberOfTrees) + .timeDecay(rcfTimeDecay) + .outputAfter(numMinSamples) + .initialAcceptFraction(numMinSamples * 1.0d / rcfSampleSize) + .parallelExecutionEnabled(false) + .compact(true) + .precision(Precision.FLOAT_32) + .boundingBoxCacheFraction(AnomalyDetectorSettings.BATCH_BOUNDING_BOX_CACHE_RATIO) + .internalShinglingEnabled(false) + // same with dimension for opportunistic memory saving + .shingleSize(1) + .build(); + assertEquals(1685208, tracker.estimateTRCFModelSize(rcf3)); + + ThresholdedRandomCutForest rcf4 = ThresholdedRandomCutForest + .builder() + .dimensions(6) + .sampleSize(rcfSampleSize) + .numberOfTrees(numberOfTrees) + .timeDecay(rcfTimeDecay) + .outputAfter(numMinSamples) + .parallelExecutionEnabled(false) + .compact(true) + .precision(Precision.FLOAT_32) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + .internalShinglingEnabled(true) + // same with dimension for opportunistic memory saving + .shingleSize(1) + .build(); + assertEquals(521304, tracker.estimateTRCFModelSize(rcf4)); + + ThresholdedRandomCutForest rcf5 = ThresholdedRandomCutForest + .builder() + .dimensions(8) + .sampleSize(rcfSampleSize) + .numberOfTrees(numberOfTrees) + .timeDecay(rcfTimeDecay) + .outputAfter(numMinSamples) + .parallelExecutionEnabled(false) + .compact(true) + .precision(Precision.FLOAT_32) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + .internalShinglingEnabled(true) + // same with dimension for opportunistic memory saving + .shingleSize(2) + .build(); + assertEquals(467340, tracker.estimateTRCFModelSize(rcf5)); + + ThresholdedRandomCutForest rcf6 = ThresholdedRandomCutForest + .builder() + .dimensions(32) + .sampleSize(rcfSampleSize) + .numberOfTrees(numberOfTrees) + .timeDecay(rcfTimeDecay) + .outputAfter(numMinSamples) + .parallelExecutionEnabled(false) + .compact(true) + .precision(Precision.FLOAT_32) + 
.boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + .internalShinglingEnabled(true) + // same with dimension for opportunistic memory saving + .shingleSize(4) + .build(); + assertEquals(603676, tracker.estimateTRCFModelSize(rcf6)); + + ThresholdedRandomCutForest rcf7 = ThresholdedRandomCutForest + .builder() + .dimensions(16) + .sampleSize(rcfSampleSize) + .numberOfTrees(numberOfTrees) + .timeDecay(rcfTimeDecay) + .outputAfter(numMinSamples) + .parallelExecutionEnabled(false) + .compact(true) + .precision(Precision.FLOAT_32) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + .internalShinglingEnabled(true) + // same with dimension for opportunistic memory saving + .shingleSize(16) + .build(); + assertEquals(401481, tracker.estimateTRCFModelSize(rcf7)); + + ThresholdedRandomCutForest rcf8 = ThresholdedRandomCutForest + .builder() + .dimensions(320) + .sampleSize(rcfSampleSize) + .numberOfTrees(numberOfTrees) + .timeDecay(rcfTimeDecay) + .outputAfter(numMinSamples) + .parallelExecutionEnabled(false) + .compact(true) + .precision(Precision.FLOAT_32) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + .internalShinglingEnabled(true) + // same with dimension for opportunistic memory saving + .shingleSize(32) + .build(); + assertEquals(1040432, tracker.estimateTRCFModelSize(rcf8)); + + ThresholdedRandomCutForest rcf9 = ThresholdedRandomCutForest + .builder() + .dimensions(320) + .sampleSize(rcfSampleSize) + .numberOfTrees(numberOfTrees) + .timeDecay(rcfTimeDecay) + .outputAfter(numMinSamples) + .parallelExecutionEnabled(false) + .compact(true) + .precision(Precision.FLOAT_32) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + .internalShinglingEnabled(true) + // same with dimension for opportunistic memory saving + .shingleSize(64) + .build(); + assertEquals(1040688, tracker.estimateTRCFModelSize(rcf9)); + + ThresholdedRandomCutForest rcf10 = ThresholdedRandomCutForest + .builder() + .dimensions(325) + .sampleSize(rcfSampleSize) + .numberOfTrees(numberOfTrees) + .timeDecay(rcfTimeDecay) + .outputAfter(numMinSamples) + .parallelExecutionEnabled(false) + .compact(true) + .precision(Precision.FLOAT_32) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + .internalShinglingEnabled(true) + // same with dimension for opportunistic memory saving + .shingleSize(65) + .build(); + expectThrows(IllegalArgumentException.class, () -> tracker.estimateTRCFModelSize(rcf10)); + } + + public void testCanAllocate() { + setUpBigHeap(); + + assertTrue(tracker.canAllocate((long) (largeHeapSize * modelMaxPercen))); + assertTrue(!tracker.canAllocate((long) (largeHeapSize * modelMaxPercen + 10))); + + long bytesToUse = 100_000; + tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR); + assertTrue(!tracker.canAllocate((long) (largeHeapSize * modelMaxPercen))); + + tracker.releaseMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR); + assertTrue(tracker.canAllocate((long) (largeHeapSize * modelMaxPercen))); + } + + public void testCannotHost() { + setUpSmallHeap(); + expectThrows(LimitExceededException.class, () -> tracker.isHostingAllowed(detectorId, trcf)); + } + + public void testMemoryToShed() { + setUpSmallHeap(); + long bytesToUse = 100_000; + assertEquals(bytesToUse, tracker.getHeapLimit()); + assertEquals((long) (smallHeapSize * modelDesiredSizePercentage), tracker.getDesiredModelSize()); + tracker.consumeMemory(bytesToUse, 
false, MemoryTracker.Origin.HC_DETECTOR); + tracker.consumeMemory(bytesToUse, true, MemoryTracker.Origin.HC_DETECTOR); + assertEquals(2 * bytesToUse, tracker.getTotalMemoryBytes()); + + assertEquals(bytesToUse, tracker.memoryToShed()); + assertTrue(!tracker.syncMemoryState(MemoryTracker.Origin.HC_DETECTOR, 2 * bytesToUse, bytesToUse)); + } +} diff --git a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java-e b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java-e new file mode 100644 index 000000000..80ef180ed --- /dev/null +++ b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java-e @@ -0,0 +1,334 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; + +import java.time.Clock; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.model.DetectorInternalState; +import org.opensearch.ad.model.DetectorProfile; +import org.opensearch.ad.model.DetectorProfileName; +import org.opensearch.ad.model.DetectorState; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.AnomalyResultTests; +import org.opensearch.ad.transport.ProfileAction; +import org.opensearch.ad.transport.ProfileNodeResponse; +import org.opensearch.ad.transport.ProfileResponse; +import org.opensearch.ad.util.*; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.transport.TransportService; + +public class MultiEntityProfileRunnerTests extends AbstractTimeSeriesTest { + private AnomalyDetectorProfileRunner runner; + private Client client; + private SecurityClientUtil clientUtil; + 
private DiscoveryNodeFilterer nodeFilter; + private int requiredSamples; + private AnomalyDetector detector; + private String detectorId; + private Set stateNError; + private DetectorInternalState.Builder result; + private String node1; + private String nodeName1; + private DiscoveryNode discoveryNode1; + + private String node2; + private String nodeName2; + private DiscoveryNode discoveryNode2; + + private long modelSize; + private String model1Id; + private String model0Id; + + private int shingleSize; + private AnomalyDetectorJob job; + private TransportService transportService; + private ADTaskManager adTaskManager; + + enum InittedEverResultStatus { + INITTED, + NOT_INITTED, + } + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyResultTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @SuppressWarnings("unchecked") + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + client = mock(Client.class); + Clock clock = mock(Clock.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + nodeFilter = mock(DiscoveryNodeFilterer.class); + requiredSamples = 128; + + detectorId = "A69pa3UBHuCbh-emo9oR"; + detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("a")); + result = new DetectorInternalState.Builder().lastUpdateTime(Instant.now()); + job = TestHelpers.randomAnomalyDetectorJob(true); + adTaskManager = mock(ADTaskManager.class); + transportService = mock(TransportService.class); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + Consumer> function = (Consumer>) args[2]; + + function.accept(Optional.of(TestHelpers.randomAdTask())); + return null; + }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(any(), any(), any(), any(), anyBoolean(), any()); + runner = new AnomalyDetectorProfileRunner( + client, + clientUtil, + xContentRegistry(), + nodeFilter, + requiredSamples, + transportService, + adTaskManager + ); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + GetRequest request = (GetRequest) args[0]; + ActionListener listener = (ActionListener) args[1]; + + String indexName = request.index(); + if (indexName.equals(CommonName.CONFIG_INDEX)) { + listener.onResponse(TestHelpers.createGetResponse(detector, detector.getId(), CommonName.CONFIG_INDEX)); + } else if (indexName.equals(ADCommonName.DETECTION_STATE_INDEX)) { + listener.onResponse(TestHelpers.createGetResponse(result.build(), detector.getId(), ADCommonName.DETECTION_STATE_INDEX)); + } else if (indexName.equals(CommonName.JOB_INDEX)) { + listener.onResponse(TestHelpers.createGetResponse(job, detector.getId(), CommonName.JOB_INDEX)); + } + + return null; + }).when(client).get(any(), any()); + + stateNError = new HashSet(); + stateNError.add(DetectorProfileName.ERROR); + stateNError.add(DetectorProfileName.STATE); + } + + @SuppressWarnings("unchecked") + private void setUpClientExecuteProfileAction(InittedEverResultStatus initted) { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[2]; + + node1 = "node1"; + nodeName1 = "nodename1"; + discoveryNode1 = new DiscoveryNode( + nodeName1, + node1, + new TransportAddress(TransportAddress.META_ADDRESS, 9300), + emptyMap(), + emptySet(), + Version.CURRENT + ); + + node2 = "node2"; + nodeName2 = 
"nodename2"; + discoveryNode2 = new DiscoveryNode( + nodeName2, + node2, + new TransportAddress(TransportAddress.META_ADDRESS, 9301), + emptyMap(), + emptySet(), + Version.CURRENT + ); + + modelSize = 712480L; + model1Id = "A69pa3UBHuCbh-emo9oR_entity_host1"; + model0Id = "A69pa3UBHuCbh-emo9oR_entity_host0"; + + shingleSize = -1; + + String clusterName = "test-cluster-name"; + + Map modelSizeMap1 = new HashMap() { + { + put(model1Id, modelSize); + } + }; + + Map modelSizeMap2 = new HashMap() { + { + put(model0Id, modelSize); + } + }; + + // one model in each node; all fully initialized + long updates = requiredSamples - 1; + if (InittedEverResultStatus.INITTED == initted) { + updates = requiredSamples + 1; + } + ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse( + discoveryNode1, + modelSizeMap1, + shingleSize, + 1L, + updates, + new ArrayList<>(), + modelSizeMap1.size() + ); + ProfileNodeResponse profileNodeResponse2 = new ProfileNodeResponse( + discoveryNode2, + modelSizeMap2, + shingleSize, + 1L, + updates, + new ArrayList<>(), + modelSizeMap2.size() + ); + List profileNodeResponses = Arrays.asList(profileNodeResponse1, profileNodeResponse2); + List failures = Collections.emptyList(); + ProfileResponse profileResponse = new ProfileResponse(new ClusterName(clusterName), profileNodeResponses, failures); + + listener.onResponse(profileResponse); + + return null; + }).when(client).execute(any(ProfileAction.class), any(), any()); + + } + + @SuppressWarnings("unchecked") + private void setUpClientSearch(InittedEverResultStatus inittedEverResultStatus) { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + SearchRequest request = (SearchRequest) args[0]; + ActionListener listener = (ActionListener) args[1]; + + AnomalyResult result = null; + if (request.source().query().toString().contains(AnomalyResult.ANOMALY_SCORE_FIELD)) { + switch (inittedEverResultStatus) { + case INITTED: + result = TestHelpers.randomAnomalyDetectResult(0.87); + listener.onResponse(TestHelpers.createSearchResponse(result)); + break; + case NOT_INITTED: + listener.onResponse(TestHelpers.createEmptySearchResponse()); + break; + default: + assertTrue("should not reach here", false); + break; + } + } + + return null; + }).when(client).search(any(), any()); + } + + public void testInit() throws InterruptedException { + setUpClientExecuteProfileAction(InittedEverResultStatus.NOT_INITTED); + setUpClientSearch(InittedEverResultStatus.NOT_INITTED); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.INIT).build(); + runner.profile(detectorId, ActionListener.wrap(response -> { + assertEquals(expectedProfile, response); + inProgressLatch.countDown(); + }, exception -> { + assertTrue("Should not reach here", false); + inProgressLatch.countDown(); + }), stateNError); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testRunning() throws InterruptedException { + setUpClientExecuteProfileAction(InittedEverResultStatus.INITTED); + setUpClientSearch(InittedEverResultStatus.INITTED); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.RUNNING).build(); + runner.profile(detectorId, ActionListener.wrap(response -> { + assertEquals(expectedProfile, response); + inProgressLatch.countDown(); + }, exception -> { + assertTrue("Should not reach here", false); + 
inProgressLatch.countDown(); + }), stateNError); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + /** + * Although profile action results indicate initted, we trust what result index tells us + * @throws InterruptedException if CountDownLatch is interrupted while waiting + */ + public void testResultIndexFinalTruth() throws InterruptedException { + setUpClientExecuteProfileAction(InittedEverResultStatus.NOT_INITTED); + setUpClientSearch(InittedEverResultStatus.INITTED); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.RUNNING).build(); + runner.profile(detectorId, ActionListener.wrap(response -> { + assertEquals(expectedProfile, response); + inProgressLatch.countDown(); + }, exception -> { + assertTrue("Should not reach here", false); + inProgressLatch.countDown(); + }), stateNError); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } +} diff --git a/src/test/java/org/opensearch/ad/NodeStateManagerTests.java-e b/src/test/java/org/opensearch/ad/NodeStateManagerTests.java-e new file mode 100644 index 000000000..9cad7d5eb --- /dev/null +++ b/src/test/java/org/opensearch/ad/NodeStateManagerTests.java-e @@ -0,0 +1,455 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; + +import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Locale; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.transport.AnomalyResultTests; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.Throttler; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodeRole; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import 
org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.search.SearchModule; +import org.opensearch.test.ClusterServiceUtils; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonName; + +import com.google.common.collect.ImmutableMap; + +public class NodeStateManagerTests extends AbstractTimeSeriesTest { + private NodeStateManager stateManager; + private Client client; + private ClientUtil clientUtil; + private Clock clock; + private Duration duration; + private Throttler throttler; + private ThreadPool context; + private AnomalyDetector detectorToCheck; + private Settings settings; + private String adId = "123"; + private String nodeId = "123"; + + private GetResponse checkpointResponse; + private ClusterService clusterService; + private ClusterSettings clusterSettings; + private AnomalyDetectorJob jobToCheck; + + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList()); + return new NamedXContentRegistry(searchModule.getNamedXContents()); + } + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyResultTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + client = mock(Client.class); + settings = Settings + .builder() + .put("plugins.anomaly_detection.max_retry_for_unresponsive_node", 3) + .put("plugins.anomaly_detection.ad_mute_minutes", TimeValue.timeValueMinutes(10)) + .build(); + clock = mock(Clock.class); + duration = Duration.ofHours(1); + context = TestHelpers.createThreadPool(); + throttler = new Throttler(clock); + + clientUtil = new ClientUtil(Settings.EMPTY, client, throttler, mock(ThreadPool.class)); + Set> nodestateSetting = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + nodestateSetting.add(MAX_RETRY_FOR_UNRESPONSIVE_NODE); + nodestateSetting.add(BACKOFF_MINUTES); + clusterSettings = new ClusterSettings(Settings.EMPTY, nodestateSetting); + + DiscoveryNode discoveryNode = new DiscoveryNode( + "node1", + OpenSearchTestCase.buildNewFakeTransportAddress(), + Collections.emptyMap(), + DiscoveryNodeRole.BUILT_IN_ROLES, + Version.CURRENT + ); + + clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); + stateManager = new NodeStateManager(client, xContentRegistry(), settings, clientUtil, clock, duration, clusterService); + + checkpointResponse = mock(GetResponse.class); + jobToCheck = TestHelpers.randomAnomalyDetectorJob(true, Instant.ofEpochMilli(1602401500000L), null); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + stateManager = null; + client = null; + clientUtil = null; + detectorToCheck = null; + } + + @SuppressWarnings("unchecked") + private String setupDetector() throws IOException { + detectorToCheck = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null, true); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. 
Its content is %s", args.length, Arrays.toString(args)), + args.length >= 2 + ); + + GetRequest request = null; + ActionListener listener = null; + if (args[0] instanceof GetRequest) { + request = (GetRequest) args[0]; + } + if (args[1] instanceof ActionListener) { + listener = (ActionListener) args[1]; + } + + assertTrue(request != null && listener != null); + listener.onResponse(TestHelpers.createGetResponse(detectorToCheck, detectorToCheck.getId(), CommonName.CONFIG_INDEX)); + + return null; + }).when(client).get(any(), any(ActionListener.class)); + return detectorToCheck.getId(); + } + + @SuppressWarnings("unchecked") + private void setupCheckpoint(boolean responseExists) throws IOException { + when(checkpointResponse.isExists()).thenReturn(responseExists); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + args.length >= 2 + ); + + GetRequest request = null; + ActionListener listener = null; + if (args[0] instanceof GetRequest) { + request = (GetRequest) args[0]; + } + if (args[1] instanceof ActionListener) { + listener = (ActionListener) args[1]; + } + + assertTrue(request != null && listener != null); + listener.onResponse(checkpointResponse); + + return null; + }).when(client).get(any(), any(ActionListener.class)); + } + + public void testGetLastError() throws IOException, InterruptedException { + String error = "blah"; + assertEquals(NodeStateManager.NO_ERROR, stateManager.getLastDetectionError(adId)); + stateManager.setLastDetectionError(adId, error); + assertEquals(error, stateManager.getLastDetectionError(adId)); + } + + public void testShouldMute() { + assertTrue(!stateManager.isMuted(nodeId, adId)); + + when(clock.millis()).thenReturn(10000L); + IntStream.range(0, 4).forEach(j -> stateManager.addPressure(nodeId, adId)); + + when(clock.millis()).thenReturn(20000L); + assertTrue(stateManager.isMuted(nodeId, adId)); + + // > 15 minutes have passed, we should not mute anymore + when(clock.millis()).thenReturn(1000001L); + assertTrue(!stateManager.isMuted(nodeId, adId)); + + // the backpressure counter should be reset + when(clock.millis()).thenReturn(100001L); + stateManager.resetBackpressureCounter(nodeId, adId); + assertTrue(!stateManager.isMuted(nodeId, adId)); + } + + public void testMaintenanceDoNothing() { + stateManager.maintenance(); + + verifyZeroInteractions(clock); + } + + public void testHasRunningQuery() throws IOException { + stateManager = new NodeStateManager( + client, + xContentRegistry(), + settings, + new ClientUtil(settings, client, throttler, context), + clock, + duration, + clusterService + ); + + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of(), null); + SearchRequest dummySearchRequest = new SearchRequest(); + assertFalse(stateManager.hasRunningQuery(detector)); + throttler.insertFilteredQuery(detector.getId(), dummySearchRequest); + assertTrue(stateManager.hasRunningQuery(detector)); + } + + public void testGetAnomalyDetector() throws IOException, InterruptedException { + String detectorId = setupDetector(); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + stateManager.getAnomalyDetector(detectorId, ActionListener.wrap(asDetector -> { + assertEquals(detectorToCheck, asDetector.get()); + inProgressLatch.countDown(); + }, exception -> { + assertTrue(false); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + 
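    /*
     * The repeated-lookup test below depends on NodeStateManager caching the parsed
     * detector after the first call, which is why the mocked client is verified to be
     * hit exactly once. A minimal sketch of such a cache, assuming a per-detector
     * NodeState map keyed by detector id (the "states" field and the exact flow are
     * illustrative, not the plugin's actual implementation):
     *
     *   private final Map<String, NodeState> states = new ConcurrentHashMap<>();
     *
     *   public void getAnomalyDetector(String adID, ActionListener<Optional<AnomalyDetector>> listener) {
     *       NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock));
     *       if (state.getDetectorDef() != null) {
     *           // cache hit: answer without touching the client
     *           listener.onResponse(Optional.of(state.getDetectorDef()));
     *       } else {
     *           // cache miss: fetch from the config index, then remember the parsed detector
     *           client.get(new GetRequest(CommonName.CONFIG_INDEX, adID), ActionListener.wrap(r -> {
     *               AnomalyDetector parsed = parse(r); // parsing elided
     *               state.setDetectorDef(parsed);
     *               listener.onResponse(Optional.of(parsed));
     *           }, listener::onFailure));
     *       }
     *   }
     */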
/** + * Test that we caches anomaly detector definition after the first call + * @throws IOException if client throws exception + * @throws InterruptedException if the current thread is interrupted while waiting + */ + @SuppressWarnings("unchecked") + public void testRepeatedGetAnomalyDetector() throws IOException, InterruptedException { + String detectorId = setupDetector(); + final CountDownLatch inProgressLatch = new CountDownLatch(2); + + stateManager.getAnomalyDetector(detectorId, ActionListener.wrap(asDetector -> { + assertEquals(detectorToCheck, asDetector.get()); + inProgressLatch.countDown(); + }, exception -> { + assertTrue(false); + inProgressLatch.countDown(); + })); + + stateManager.getAnomalyDetector(detectorId, ActionListener.wrap(asDetector -> { + assertEquals(detectorToCheck, asDetector.get()); + inProgressLatch.countDown(); + }, exception -> { + assertTrue(false); + inProgressLatch.countDown(); + })); + + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + + verify(client, times(1)).get(any(), any(ActionListener.class)); + } + + public void getCheckpointTestTemplate(boolean exists) throws IOException { + setupCheckpoint(exists); + when(clock.instant()).thenReturn(Instant.MIN); + stateManager + .getDetectorCheckpoint(adId, ActionListener.wrap(checkpointExists -> { assertEquals(exists, checkpointExists); }, exception -> { + for (StackTraceElement ste : exception.getStackTrace()) { + logger.info(ste); + } + assertTrue(false); + })); + } + + public void testCheckpointExists() throws IOException { + getCheckpointTestTemplate(true); + } + + public void testCheckpointNotExists() throws IOException { + getCheckpointTestTemplate(false); + } + + public void testMaintenanceNotRemove() throws IOException { + setupCheckpoint(true); + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1)); + stateManager + .getDetectorCheckpoint( + adId, + ActionListener.wrap(gotCheckpoint -> { assertTrue(gotCheckpoint); }, exception -> assertTrue(false)) + ); + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1)); + stateManager.maintenance(); + stateManager + .getDetectorCheckpoint(adId, ActionListener.wrap(gotCheckpoint -> assertTrue(gotCheckpoint), exception -> assertTrue(false))); + verify(client, times(1)).get(any(), any()); + } + + public void testMaintenanceRemove() throws IOException { + setupCheckpoint(true); + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1)); + stateManager + .getDetectorCheckpoint( + adId, + ActionListener.wrap(gotCheckpoint -> { assertTrue(gotCheckpoint); }, exception -> assertTrue(false)) + ); + when(clock.instant()).thenReturn(Instant.ofEpochSecond(7200L)); + stateManager.maintenance(); + stateManager + .getDetectorCheckpoint( + adId, + ActionListener.wrap(gotCheckpoint -> { assertTrue(gotCheckpoint); }, exception -> assertTrue(false)) + ); + verify(client, times(2)).get(any(), any()); + } + + public void testColdStartRunning() { + assertTrue(!stateManager.isColdStartRunning(adId)); + stateManager.markColdStartRunning(adId); + assertTrue(stateManager.isColdStartRunning(adId)); + } + + public void testSettingUpdateMaxRetry() { + when(clock.millis()).thenReturn(System.currentTimeMillis()); + stateManager.addPressure(nodeId, adId); + // In setUp method, we mute after 3 tries + assertTrue(!stateManager.isMuted(nodeId, adId)); + + Settings newSettings = Settings.builder().put(AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.getKey(), "1").build(); + Settings.Builder target = Settings.builder(); + 
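        // updateDynamicSettings validates newSettings and copies the changed dynamic values into
        // target; applySettings then notifies the registered update consumers, which is how the
        // state manager is expected to pick up the lowered max-retry value asserted on below.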
clusterSettings.updateDynamicSettings(newSettings, target, Settings.builder(), "test"); + clusterSettings.applySettings(target.build()); + stateManager.addPressure(nodeId, adId); + // since we have one violation and the max is 1, this is flagged as muted + assertTrue(stateManager.isMuted(nodeId, adId)); + } + + public void testSettingUpdateBackOffMin() { + when(clock.millis()).thenReturn(1000L); + // In setUp method, we mute after 3 tries + for (int i = 0; i < 4; i++) { + stateManager.addPressure(nodeId, adId); + } + + assertTrue(stateManager.isMuted(nodeId, adId)); + + Settings newSettings = Settings.builder().put(AnomalyDetectorSettings.BACKOFF_MINUTES.getKey(), "1m").build(); + Settings.Builder target = Settings.builder(); + clusterSettings.updateDynamicSettings(newSettings, target, Settings.builder(), "test"); + clusterSettings.applySettings(target.build()); + stateManager.addPressure(nodeId, adId); + // move the clobk by 1000 milliseconds + // when evaluating isMuted, 62000 - 1000 (last mute time) > 60000, which + // make isMuted true + when(clock.millis()).thenReturn(62000L); + assertTrue(!stateManager.isMuted(nodeId, adId)); + } + + @SuppressWarnings("unchecked") + private String setupJob() throws IOException { + String detectorId = jobToCheck.getName(); + + doAnswer(invocation -> { + GetRequest request = invocation.getArgument(0); + ActionListener listener = invocation.getArgument(1); + if (request.index().equals(CommonName.JOB_INDEX)) { + listener.onResponse(TestHelpers.createGetResponse(jobToCheck, detectorId, CommonName.JOB_INDEX)); + } + return null; + }).when(client).get(any(), any(ActionListener.class)); + + return detectorId; + } + + public void testGetAnomalyJob() throws IOException, InterruptedException { + String detectorId = setupJob(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + stateManager.getAnomalyDetectorJob(detectorId, ActionListener.wrap(asDetector -> { + assertEquals(jobToCheck, asDetector.get()); + inProgressLatch.countDown(); + }, exception -> { + assertTrue(false); + inProgressLatch.countDown(); + })); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + /** + * Test that we caches anomaly detector job definition after the first call + * @throws IOException if client throws exception + * @throws InterruptedException if the current thread is interrupted while waiting + */ + @SuppressWarnings("unchecked") + public void testRepeatedGetAnomalyJob() throws IOException, InterruptedException { + String detectorId = setupJob(); + final CountDownLatch inProgressLatch = new CountDownLatch(2); + + stateManager.getAnomalyDetectorJob(detectorId, ActionListener.wrap(asDetector -> { + assertEquals(jobToCheck, asDetector.get()); + inProgressLatch.countDown(); + }, exception -> { + assertTrue(false); + inProgressLatch.countDown(); + })); + + stateManager.getAnomalyDetectorJob(detectorId, ActionListener.wrap(asDetector -> { + assertEquals(jobToCheck, asDetector.get()); + inProgressLatch.countDown(); + }, exception -> { + assertTrue(false); + inProgressLatch.countDown(); + })); + + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + + verify(client, times(1)).get(any(), any(ActionListener.class)); + } +} diff --git a/src/test/java/org/opensearch/ad/NodeStateTests.java-e b/src/test/java/org/opensearch/ad/NodeStateTests.java-e new file mode 100644 index 000000000..c48afdb76 --- /dev/null +++ b/src/test/java/org/opensearch/ad/NodeStateTests.java-e @@ -0,0 +1,104 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch 
Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; + +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.TimeSeriesException; + +public class NodeStateTests extends OpenSearchTestCase { + private NodeState state; + private Clock clock; + + @Override + public void setUp() throws Exception { + super.setUp(); + clock = mock(Clock.class); + state = new NodeState("123", clock); + } + + private Duration duration = Duration.ofHours(1); + + public void testMaintenanceNotRemoveSingle() throws IOException { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1000)); + state.setDetectorDef(TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null)); + + when(clock.instant()).thenReturn(Instant.MIN); + assertTrue(!state.expired(duration)); + } + + public void testMaintenanceNotRemove() throws IOException { + when(clock.instant()).thenReturn(Instant.ofEpochSecond(1000)); + state.setDetectorDef(TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null)); + state.setLastDetectionError(null); + + when(clock.instant()).thenReturn(Instant.ofEpochSecond(3700)); + assertTrue(!state.expired(duration)); + } + + public void testMaintenanceRemoveLastError() throws IOException { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1000)); + state + .setDetectorDef( + + TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null) + ); + state.setLastDetectionError(null); + + when(clock.instant()).thenReturn(Instant.ofEpochSecond(3700)); + assertTrue(state.expired(duration)); + } + + public void testMaintenancRemoveDetector() throws IOException { + when(clock.instant()).thenReturn(Instant.MIN); + state.setDetectorDef(TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null)); + when(clock.instant()).thenReturn(Instant.MAX); + assertTrue(state.expired(duration)); + + } + + public void testMaintenanceFlagNotRemove() throws IOException { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1000)); + state.setCheckpointExists(true); + when(clock.instant()).thenReturn(Instant.MIN); + assertTrue(!state.expired(duration)); + } + + public void testMaintenancFlagRemove() throws IOException { + when(clock.instant()).thenReturn(Instant.MIN); + state.setCheckpointExists(true); + when(clock.instant()).thenReturn(Instant.MIN); + assertTrue(!state.expired(duration)); + } + + public void testMaintenanceLastColdStartRemoved() { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1000)); + state.setException(new TimeSeriesException("123", "")); + when(clock.instant()).thenReturn(Instant.ofEpochSecond(3700)); + assertTrue(state.expired(duration)); + } + + public void testMaintenanceLastColdStartNotRemoved() { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(1_000_000L)); + state.setException(new TimeSeriesException("123", "")); + when(clock.instant()).thenReturn(Instant.ofEpochSecond(3700)); + assertTrue(!state.expired(duration)); + } +} diff --git a/src/test/java/org/opensearch/ad/ODFERestTestCase.java b/src/test/java/org/opensearch/ad/ODFERestTestCase.java index 
8ed1f86af..e7b69e388 100644 --- a/src/test/java/org/opensearch/ad/ODFERestTestCase.java +++ b/src/test/java/org/opensearch/ad/ODFERestTestCase.java @@ -61,10 +61,10 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.rest.SecureRestClientBuilder; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.rest.RestStatus; import org.opensearch.test.rest.OpenSearchRestTestCase; import com.google.gson.JsonArray; diff --git a/src/test/java/org/opensearch/ad/ODFERestTestCase.java-e b/src/test/java/org/opensearch/ad/ODFERestTestCase.java-e new file mode 100644 index 000000000..f76c82053 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ODFERestTestCase.java-e @@ -0,0 +1,292 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad; + +import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_PER_ROUTE; +import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_TOTAL; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_ENABLED; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH; + +import java.io.IOException; +import java.io.InputStreamReader; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.Charset; +import java.nio.file.Path; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import javax.net.ssl.SSLEngine; + +import org.apache.hc.client5.http.auth.AuthScope; +import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; +import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; +import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManager; +import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder; +import org.apache.hc.client5.http.ssl.ClientTlsStrategyBuilder; +import org.apache.hc.client5.http.ssl.NoopHostnameVerifier; +import org.apache.hc.core5.function.Factory; +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpHost; +import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.hc.core5.http.nio.ssl.TlsStrategy; +import org.apache.hc.core5.reactor.ssl.TlsDetails; +import org.apache.hc.core5.ssl.SSLContextBuilder; +import org.apache.hc.core5.util.Timeout; +import org.junit.After; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.client.RestClientBuilder; +import org.opensearch.common.io.PathUtils; +import org.opensearch.common.settings.Settings; +import 
org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.rest.SecureRestClientBuilder; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.test.rest.OpenSearchRestTestCase; + +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; + +/** + * ODFE integration test base class to support both security disabled and enabled ODFE cluster. + */ +public abstract class ODFERestTestCase extends OpenSearchRestTestCase { + + protected boolean isHttps() { + boolean isHttps = Optional.ofNullable(System.getProperty("https")).map("true"::equalsIgnoreCase).orElse(false); + if (isHttps) { + // currently only external cluster is supported for security enabled testing + if (!Optional.ofNullable(System.getProperty("tests.rest.cluster")).isPresent()) { + throw new RuntimeException("cluster url should be provided for security enabled testing"); + } + } + + return isHttps; + } + + @Override + protected String getProtocol() { + return isHttps() ? "https" : "http"; + } + + @Override + protected Settings restAdminSettings() { + return Settings + .builder() + // disable the warning exception for admin client since it's only used for cleanup. + .put("strictDeprecationMode", false) + .put("http.port", 9200) + .put(OPENSEARCH_SECURITY_SSL_HTTP_ENABLED, isHttps()) + .put(OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH, "sample.pem") + .put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH, "test-kirk.jks") + .put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD, "changeit") + .put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD, "changeit") + .build(); + } + + // Utility fn for deleting indices. Should only be used when not allowed in a regular context + // (e.g., deleting system indices) + protected static void deleteIndexWithAdminClient(String name) throws IOException { + Request request = new Request("DELETE", "/" + name); + adminClient().performRequest(request); + } + + // Utility fn for checking if an index exists. 
Should only be used when not allowed in a regular context + // (e.g., checking existence of system indices) + protected static boolean indexExistsWithAdminClient(String indexName) throws IOException { + Request request = new Request("HEAD", "/" + indexName); + Response response = adminClient().performRequest(request); + return RestStatus.OK.getStatus() == response.getStatusLine().getStatusCode(); + } + + @Override + protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOException { + boolean strictDeprecationMode = settings.getAsBoolean("strictDeprecationMode", true); + RestClientBuilder builder = RestClient.builder(hosts); + if (isHttps()) { + String keystore = settings.get(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH); + if (Objects.nonNull(keystore)) { + URI uri = null; + try { + uri = this.getClass().getClassLoader().getResource("security/sample.pem").toURI(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + Path configPath = PathUtils.get(uri).getParent().toAbsolutePath(); + return new SecureRestClientBuilder(settings, configPath).build(); + } else { + configureHttpsClient(builder, settings); + builder.setStrictDeprecationMode(strictDeprecationMode); + return builder.build(); + } + + } else { + configureClient(builder, settings); + builder.setStrictDeprecationMode(strictDeprecationMode); + return builder.build(); + } + + } + + @SuppressWarnings("unchecked") + @After + protected void wipeAllODFEIndices() throws IOException { + Response response = adminClient().performRequest(new Request("GET", "/_cat/indices?format=json&expand_wildcards=all")); + XContentType xContentType = XContentType.fromMediaType(response.getEntity().getContentType()); + try ( + XContentParser parser = xContentType + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + response.getEntity().getContent() + ) + ) { + XContentParser.Token token = parser.nextToken(); + List> parserList = null; + if (token == XContentParser.Token.START_ARRAY) { + parserList = parser.listOrderedMap().stream().map(obj -> (Map) obj).collect(Collectors.toList()); + } else { + parserList = Collections.singletonList(parser.mapOrdered()); + } + + for (Map index : parserList) { + String indexName = (String) index.get("index"); + if (indexName != null && !".opendistro_security".equals(indexName)) { + adminClient().performRequest(new Request("DELETE", "/" + indexName)); + } + } + } + } + + protected static void configureHttpsClient(RestClientBuilder builder, Settings settings) throws IOException { + Map headers = ThreadContext.buildDefaultHeaders(settings); + Header[] defaultHeaders = new Header[headers.size()]; + int i = 0; + for (Map.Entry entry : headers.entrySet()) { + defaultHeaders[i++] = new BasicHeader(entry.getKey(), entry.getValue()); + } + builder.setDefaultHeaders(defaultHeaders); + builder.setHttpClientConfigCallback(httpClientBuilder -> { + String userName = Optional + .ofNullable(System.getProperty("user")) + .orElseThrow(() -> new RuntimeException("user name is missing")); + String password = Optional + .ofNullable(System.getProperty("password")) + .orElseThrow(() -> new RuntimeException("password is missing")); + BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + final AuthScope anyScope = new AuthScope(null, -1); + credentialsProvider.setCredentials(anyScope, new UsernamePasswordCredentials(userName, password.toCharArray())); + try { + final TlsStrategy tlsStrategy = ClientTlsStrategyBuilder 
+ .create() + .setHostnameVerifier(NoopHostnameVerifier.INSTANCE) + .setSslContext(SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build()) + // See https://issues.apache.org/jira/browse/HTTPCLIENT-2219 + .setTlsDetailsFactory(new Factory() { + @Override + public TlsDetails create(final SSLEngine sslEngine) { + return new TlsDetails(sslEngine.getSession(), sslEngine.getApplicationProtocol()); + } + }) + .build(); + final PoolingAsyncClientConnectionManager connectionManager = PoolingAsyncClientConnectionManagerBuilder + .create() + .setMaxConnPerRoute(DEFAULT_MAX_CONN_PER_ROUTE) + .setMaxConnTotal(DEFAULT_MAX_CONN_TOTAL) + .setTlsStrategy(tlsStrategy) + .build(); + return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider).setConnectionManager(connectionManager); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + final String socketTimeoutString = settings.get(CLIENT_SOCKET_TIMEOUT); + final TimeValue socketTimeout = TimeValue + .parseTimeValue(socketTimeoutString == null ? "60s" : socketTimeoutString, CLIENT_SOCKET_TIMEOUT); + builder.setRequestConfigCallback(conf -> { + Timeout timeout = Timeout.ofMilliseconds(Math.toIntExact(socketTimeout.getMillis())); + conf.setConnectTimeout(timeout); + conf.setResponseTimeout(timeout); + return conf; + }); + if (settings.hasValue(CLIENT_PATH_PREFIX)) { + builder.setPathPrefix(settings.get(CLIENT_PATH_PREFIX)); + } + } + + /** + * wipeAllIndices won't work since it cannot delete security index. Use wipeAllODFEIndices instead. + */ + @Override + protected boolean preserveIndicesUponCompletion() { + return true; + } + + protected void waitAllSyncheticDataIngested(int expectedSize, String datasetName, RestClient client) throws Exception { + int maxWaitCycles = 3; + do { + Request request = new Request("POST", String.format(Locale.ROOT, "/%s/_search", datasetName)); + request + .setJsonEntity( + String + .format( + Locale.ROOT, + "{\"query\": {" + + " \"match_all\": {}" + + " }," + + " \"size\": 1," + + " \"sort\": [" + + " {" + + " \"timestamp\": {" + + " \"order\": \"desc\"" + + " }" + + " }" + + " ]}" + ) + ); + // Make sure all of the test data has been ingested + // Expected response: + // "_index":"synthetic","_type":"_doc","_id":"10080","_score":null,"_source":{"timestamp":"2019-11-08T00:00:00Z","Feature1":156.30028000000001,"Feature2":100.211205,"host":"host1"},"sort":[1573171200000]} + Response response = client.performRequest(request); + JsonObject json = JsonParser + .parseReader(new InputStreamReader(response.getEntity().getContent(), Charset.defaultCharset())) + .getAsJsonObject(); + JsonArray hits = json.getAsJsonObject("hits").getAsJsonArray("hits"); + if (hits != null + && hits.size() == 1 + && expectedSize - 1 == hits.get(0).getAsJsonObject().getAsJsonPrimitive("_id").getAsLong()) { + break; + } else { + request = new Request("POST", String.format(Locale.ROOT, "/%s/_refresh", datasetName)); + client.performRequest(request); + } + Thread.sleep(1_000); + } while (maxWaitCycles-- >= 0); + } +} diff --git a/src/test/java/org/opensearch/ad/breaker/ADCircuitBreakerServiceTests.java-e b/src/test/java/org/opensearch/ad/breaker/ADCircuitBreakerServiceTests.java-e new file mode 100644 index 000000000..7a5be47b6 --- /dev/null +++ b/src/test/java/org/opensearch/ad/breaker/ADCircuitBreakerServiceTests.java-e @@ -0,0 +1,115 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 
license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.breaker; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.when; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.monitor.jvm.JvmService; +import org.opensearch.monitor.jvm.JvmStats; + +public class ADCircuitBreakerServiceTests { + + @InjectMocks + private ADCircuitBreakerService adCircuitBreakerService; + + @Mock + JvmService jvmService; + + @Mock + JvmStats jvmStats; + + @Mock + JvmStats.Mem mem; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void testRegisterBreaker() { + adCircuitBreakerService.registerBreaker(BreakerName.MEM.getName(), new MemoryCircuitBreaker(jvmService)); + CircuitBreaker breaker = adCircuitBreakerService.getBreaker(BreakerName.MEM.getName()); + + assertThat(breaker, is(notNullValue())); + } + + @Test + public void testRegisterBreakerNull() { + CircuitBreaker breaker = adCircuitBreakerService.getBreaker(BreakerName.MEM.getName()); + + assertThat(breaker, is(nullValue())); + } + + @Test + public void testUnregisterBreaker() { + adCircuitBreakerService.registerBreaker(BreakerName.MEM.getName(), new MemoryCircuitBreaker(jvmService)); + CircuitBreaker breaker = adCircuitBreakerService.getBreaker(BreakerName.MEM.getName()); + assertThat(breaker, is(notNullValue())); + adCircuitBreakerService.unregisterBreaker(BreakerName.MEM.getName()); + breaker = adCircuitBreakerService.getBreaker(BreakerName.MEM.getName()); + assertThat(breaker, is(nullValue())); + } + + @Test + public void testUnregisterBreakerNull() { + adCircuitBreakerService.registerBreaker(BreakerName.MEM.getName(), new MemoryCircuitBreaker(jvmService)); + adCircuitBreakerService.unregisterBreaker(null); + CircuitBreaker breaker = adCircuitBreakerService.getBreaker(BreakerName.MEM.getName()); + assertThat(breaker, is(notNullValue())); + } + + @Test + public void testClearBreakers() { + adCircuitBreakerService.registerBreaker(BreakerName.CPU.getName(), new MemoryCircuitBreaker(jvmService)); + CircuitBreaker breaker = adCircuitBreakerService.getBreaker(BreakerName.CPU.getName()); + assertThat(breaker, is(notNullValue())); + adCircuitBreakerService.clearBreakers(); + breaker = adCircuitBreakerService.getBreaker(BreakerName.CPU.getName()); + assertThat(breaker, is(nullValue())); + } + + @Test + public void testInit() { + assertThat(adCircuitBreakerService.init(), is(notNullValue())); + } + + @Test + public void testIsOpen() { + when(jvmService.stats()).thenReturn(jvmStats); + when(jvmStats.getMem()).thenReturn(mem); + when(mem.getHeapUsedPercent()).thenReturn((short) 50); + + adCircuitBreakerService.registerBreaker(BreakerName.MEM.getName(), new MemoryCircuitBreaker(jvmService)); + assertThat(adCircuitBreakerService.isOpen(), equalTo(false)); + } + + @Test + public void testIsOpen1() { + when(jvmService.stats()).thenReturn(jvmStats); + when(jvmStats.getMem()).thenReturn(mem); + when(mem.getHeapUsedPercent()).thenReturn((short) 90); + + adCircuitBreakerService.registerBreaker(BreakerName.MEM.getName(), new MemoryCircuitBreaker(jvmService)); + 
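        // What this assertion exercises (an inference from the stubbing above, not from the breaker's
        // source): MemoryCircuitBreaker reads the heap-used percent via jvmService.stats().getMem(),
        // and 90% is above its default threshold, so isOpen() should now report true, whereas the
        // 50% stub in testIsOpen above left it false.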
assertThat(adCircuitBreakerService.isOpen(), equalTo(true)); + } +} diff --git a/src/test/java/org/opensearch/ad/breaker/MemoryCircuitBreakerTests.java-e b/src/test/java/org/opensearch/ad/breaker/MemoryCircuitBreakerTests.java-e new file mode 100644 index 000000000..e9249df82 --- /dev/null +++ b/src/test/java/org/opensearch/ad/breaker/MemoryCircuitBreakerTests.java-e @@ -0,0 +1,73 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.breaker; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.when; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.monitor.jvm.JvmService; +import org.opensearch.monitor.jvm.JvmStats; + +public class MemoryCircuitBreakerTests { + + @Mock + JvmService jvmService; + + @Mock + JvmStats jvmStats; + + @Mock + JvmStats.Mem mem; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + when(jvmService.stats()).thenReturn(jvmStats); + when(jvmStats.getMem()).thenReturn(mem); + when(mem.getHeapUsedPercent()).thenReturn((short) 50); + } + + @Test + public void testIsOpen() { + CircuitBreaker breaker = new MemoryCircuitBreaker(jvmService); + + assertThat(breaker.isOpen(), equalTo(false)); + } + + @Test + public void testIsOpen1() { + CircuitBreaker breaker = new MemoryCircuitBreaker((short) 90, jvmService); + + assertThat(breaker.isOpen(), equalTo(false)); + } + + @Test + public void testIsOpen2() { + CircuitBreaker breaker = new MemoryCircuitBreaker(jvmService); + + when(mem.getHeapUsedPercent()).thenReturn((short) 95); + assertThat(breaker.isOpen(), equalTo(true)); + } + + @Test + public void testIsOpen3() { + CircuitBreaker breaker = new MemoryCircuitBreaker((short) 90, jvmService); + + when(mem.getHeapUsedPercent()).thenReturn((short) 95); + assertThat(breaker.isOpen(), equalTo(true)); + } +} diff --git a/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java b/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java index bda92df36..d1dde1654 100644 --- a/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java +++ b/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java @@ -48,7 +48,7 @@ import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.client.Response; import org.opensearch.common.settings.Settings; -import org.opensearch.rest.RestStatus; +import org.opensearch.core.rest.RestStatus; import org.opensearch.test.rest.OpenSearchRestTestCase; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.util.RestHandlerUtils; diff --git a/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java-e b/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java-e new file mode 100644 index 000000000..d1dde1654 --- /dev/null +++ b/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java-e @@ -0,0 +1,471 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.ad.bwc; + +import static org.opensearch.ad.rest.ADRestTestUtils.DetectorType.MULTI_CATEGORY_HC_DETECTOR; +import static org.opensearch.ad.rest.ADRestTestUtils.DetectorType.SINGLE_CATEGORY_HC_DETECTOR; +import static org.opensearch.ad.rest.ADRestTestUtils.DetectorType.SINGLE_ENTITY_DETECTOR; +import static org.opensearch.ad.rest.ADRestTestUtils.countADResultOfDetector; +import static org.opensearch.ad.rest.ADRestTestUtils.countDetectors; +import static org.opensearch.ad.rest.ADRestTestUtils.createAnomalyDetector; +import static org.opensearch.ad.rest.ADRestTestUtils.deleteDetector; +import static org.opensearch.ad.rest.ADRestTestUtils.getDetectorWithJobAndTask; +import static org.opensearch.ad.rest.ADRestTestUtils.getDocCountOfIndex; +import static org.opensearch.ad.rest.ADRestTestUtils.ingestTestDataForHistoricalAnalysis; +import static org.opensearch.ad.rest.ADRestTestUtils.searchLatestAdTaskOfDetector; +import static org.opensearch.ad.rest.ADRestTestUtils.startAnomalyDetectorDirectly; +import static org.opensearch.ad.rest.ADRestTestUtils.startHistoricalAnalysis; +import static org.opensearch.ad.rest.ADRestTestUtils.stopHistoricalAnalysis; +import static org.opensearch.ad.rest.ADRestTestUtils.stopRealtimeJob; +import static org.opensearch.ad.rest.ADRestTestUtils.waitUntilTaskDone; +import static org.opensearch.timeseries.util.RestHandlerUtils.ANOMALY_DETECTOR_JOB; +import static org.opensearch.timeseries.util.RestHandlerUtils.HISTORICAL_ANALYSIS_TASK; +import static org.opensearch.timeseries.util.RestHandlerUtils.REALTIME_TASK; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.hc.core5.http.HttpEntity; +import org.junit.Assert; +import org.junit.Before; +import org.opensearch.ad.mock.model.MockSimpleLog; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.rest.ADRestTestUtils; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.test.rest.OpenSearchRestTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.util.RestHandlerUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class ADBackwardsCompatibilityIT extends OpenSearchRestTestCase { + + private static final ClusterType CLUSTER_TYPE = ClusterType.parse(System.getProperty("tests.rest.bwcsuite")); + private static final String CLUSTER_NAME = System.getProperty("tests.clustername"); + private static final String MIXED_CLUSTER_TEST_ROUND = System.getProperty("tests.rest.bwcsuite_round"); + private String dataIndexName = "test_data_for_ad_plugin"; + private int detectionIntervalInMinutes = 1; + private int windowDelayIntervalInMinutes = 1; + private String aggregationMethod = "sum"; + private int totalDocsPerCategory = 10_000; + private int categoryFieldSize = 2; + private List runningRealtimeDetectors; + private List historicalDetectors; + + @Before + public void setUp() throws Exception { + super.setUp(); + this.runningRealtimeDetectors = new ArrayList<>(); + this.historicalDetectors = new ArrayList<>(); + } + + @Override + protected final 
boolean preserveIndicesUponCompletion() { + return true; + } + + @Override + protected final boolean preserveReposUponCompletion() { + return true; + } + + @Override + protected boolean preserveTemplatesUponCompletion() { + return true; + } + + @Override + protected final Settings restClientSettings() { + return Settings + .builder() + .put(super.restClientSettings()) + // increase the timeout here to 90 seconds to handle long waits for a green + // cluster health. the waits for green need to be longer than a minute to + // account for delayed shards + .put(OpenSearchRestTestCase.CLIENT_SOCKET_TIMEOUT, "90s") + .build(); + } + + private enum ClusterType { + OLD, + MIXED, + UPGRADED; + + public static ClusterType parse(String value) { + switch (value) { + case "old_cluster": + return OLD; + case "mixed_cluster": + return MIXED; + case "upgraded_cluster": + return UPGRADED; + default: + throw new AssertionError("unknown cluster type: " + value); + } + } + } + + @SuppressWarnings("unchecked") + public void testBackwardsCompatibility() throws Exception { + String uri = getUri(); + Map> responseMap = (Map>) getAsMap(uri).get("nodes"); + for (Map response : responseMap.values()) { + List> plugins = (List>) response.get("plugins"); + Set pluginNames = plugins.stream().map(map -> map.get("name")).collect(Collectors.toSet()); + switch (CLUSTER_TYPE) { + case OLD: + // Ingest test data + ingestTestDataForHistoricalAnalysis( + client(), + dataIndexName, + detectionIntervalInMinutes, + true, + 10, + totalDocsPerCategory, + categoryFieldSize + ); + assertEquals(totalDocsPerCategory * categoryFieldSize * 2, getDocCountOfIndex(client(), dataIndexName)); + Assert.assertTrue(pluginNames.contains("opensearch-anomaly-detection")); + Assert.assertTrue(pluginNames.contains("opensearch-job-scheduler")); + + // Create single entity detector and start realtime job + createRealtimeAnomalyDetectorsAndStart(SINGLE_ENTITY_DETECTOR); + + // Create single category HC detector and start realtime job + createRealtimeAnomalyDetectorsAndStart(SINGLE_CATEGORY_HC_DETECTOR); + + // Create single entity historical detector and start historical analysis + createHistoricalAnomalyDetectorsAndStart(); + + // Verify cluster has correct number of detectors now + verifyAnomalyDetectorCount(TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI, 3); + // Verify cluster has correct number of detectors with search detector API + assertEquals(3, countDetectors(client(), null)); + + // Verify all realtime detectors are running realtime job + verifyAllRealtimeJobsRunning(); + break; + case MIXED: + // TODO: We have no way to specify whether send request to old node or new node now. + // Add more test later when it's possible to specify request node. 
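                    // Context for the detector counts asserted later in this branch (inferred from the
                    // OLD branch above): the old cluster created 3 detectors, and every mixed-cluster
                    // round adds 2 more (one single-entity, one single-category HC), which is where the
                    // numberOfDetector = 3 + 2 * mixedClusterTestRound arithmetic below comes from.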
+ Assert.assertTrue(pluginNames.contains("opensearch-anomaly-detection")); + Assert.assertTrue(pluginNames.contains("opensearch-job-scheduler")); + + // Create single entity detector and start realtime job + createRealtimeAnomalyDetectorsAndStart(SINGLE_ENTITY_DETECTOR); + + // Create single category HC detector and start realtime job + createRealtimeAnomalyDetectorsAndStart(SINGLE_CATEGORY_HC_DETECTOR); + + int mixedClusterTestRound = getMixedClusterTestRound(); + int numberOfDetector = 3 + 2 * mixedClusterTestRound; + // Verify cluster has correct number of detectors now + verifyAnomalyDetectorCount(TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI, numberOfDetector); + // Verify cluster has correct number of detectors with search detector API + assertEquals(numberOfDetector, countDetectors(client(), null)); + + // Verify all realtime detectors are running realtime job + verifyAllRealtimeJobsRunning(); + break; + case UPGRADED: + // This branch is for testing full upgraded cluster. That means all nodes in cluster are running + // latest AD version. + Assert.assertTrue(pluginNames.contains("opensearch-anomaly-detection")); + Assert.assertTrue(pluginNames.contains("opensearch-job-scheduler")); + + Map detectors = new HashMap<>(); + + // Create single entity detector and start realtime job + List singleEntityDetectorResults = createRealtimeAnomalyDetectorsAndStart(SINGLE_ENTITY_DETECTOR); + detectors.put(SINGLE_ENTITY_DETECTOR, singleEntityDetectorResults.get(0)); + // Start historical analysis for single entity detector + startHistoricalAnalysisOnNewNode(singleEntityDetectorResults.get(0), ADTaskType.HISTORICAL_SINGLE_ENTITY.name()); + + // Create single category HC detector and start realtime job + List singleCategoryHCResults = createRealtimeAnomalyDetectorsAndStart(SINGLE_CATEGORY_HC_DETECTOR); + detectors.put(SINGLE_CATEGORY_HC_DETECTOR, singleEntityDetectorResults.get(0)); + // Start historical analysis for single category HC detector + startHistoricalAnalysisOnNewNode(singleCategoryHCResults.get(0), ADTaskType.HISTORICAL_HC_DETECTOR.name()); + + // Create multi category HC detector and start realtime job + List multiCategoryHCResults = createRealtimeAnomalyDetectorsAndStart(MULTI_CATEGORY_HC_DETECTOR); + detectors.put(MULTI_CATEGORY_HC_DETECTOR, singleEntityDetectorResults.get(0)); + // Start historical analysis for multi category HC detector + startHistoricalAnalysisOnNewNode(multiCategoryHCResults.get(0), ADTaskType.HISTORICAL_HC_DETECTOR.name()); + + // Verify cluster has correct number of detectors now + verifyAnomalyDetectorCount(TestHelpers.AD_BASE_DETECTORS_URI, 6); + // Verify cluster has correct number of detectors with search detector API + assertEquals(6, countDetectors(client(), null)); + + // Start realtime job for historical detector created on old cluster and check realtime job running. 
+ startRealtimeJobForHistoricalDetectorOnNewNode(); + + // Verify all realtime detectors are running realtime job + verifyAllRealtimeJobsRunning(); + // Verify realtime and historical task exists for all running realtime detector + verifyAdTasks(); + + // Stop and delete detector + stopAndDeleteDetectors(); + break; + } + break; + } + } + + private int getMixedClusterTestRound() { + int mixedClusterTestRound = 0; + switch (MIXED_CLUSTER_TEST_ROUND) { + case "first": + mixedClusterTestRound = 1; + break; + case "second": + mixedClusterTestRound = 2; + break; + case "third": + mixedClusterTestRound = 3; + break; + default: + break; + } + return mixedClusterTestRound; + } + + private void verifyAdTasks() throws InterruptedException, IOException { + boolean realtimeTaskMissing = false; + int i = 0; + int maxRetryTimes = 10; + do { + i++; + for (String detectorId : runningRealtimeDetectors) { + Map jobAndTask = getDetectorWithJobAndTask(client(), detectorId); + AnomalyDetectorJob job = (AnomalyDetectorJob) jobAndTask.get(ANOMALY_DETECTOR_JOB); + ADTask historicalTask = (ADTask) jobAndTask.get(HISTORICAL_ANALYSIS_TASK); + ADTask realtimeTask = (ADTask) jobAndTask.get(REALTIME_TASK); + assertTrue(job.isEnabled()); + assertNotNull(historicalTask); + if (realtimeTask == null) { + realtimeTaskMissing = true; + } + if (i >= maxRetryTimes) { + assertNotNull(realtimeTask); + assertFalse(realtimeTask.isDone()); + } + } + if (realtimeTaskMissing) { + Thread.sleep(10_000);// sleep 10 seconds to wait for realtime job to backfill realtime task + } + } while (realtimeTaskMissing && i < maxRetryTimes); + } + + private void stopAndDeleteDetectors() throws Exception { + for (String detectorId : runningRealtimeDetectors) { + deleteRunningDetector(detectorId); + Response stopRealtimeJobResponse = stopRealtimeJob(client(), detectorId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(stopRealtimeJobResponse)); + try { + Response stopHistoricalAnalysisResponse = stopHistoricalAnalysis(client(), detectorId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(stopHistoricalAnalysisResponse)); + } catch (Exception e) { + if (!ExceptionUtil.getErrorMessage(e).contains("No running task found")) { + throw e; + } + } + Map jobAndTask = getDetectorWithJobAndTask(client(), detectorId); + AnomalyDetectorJob job = (AnomalyDetectorJob) jobAndTask.get(ANOMALY_DETECTOR_JOB); + ADTask historicalAdTask = (ADTask) jobAndTask.get(HISTORICAL_ANALYSIS_TASK); + if (!job.isEnabled() && historicalAdTask.isDone()) { + Response deleteDetectorResponse = deleteDetector(client(), detectorId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteDetectorResponse)); + } + } + } + + // This function can only run on new AD version(>=1.1). + // TODO: execute this function on new node in mixed cluster when we have way to send request to specific node. + private void startHistoricalAnalysisOnNewNode(String detectorId, String taskType) throws IOException, InterruptedException { + String taskId = startHistoricalAnalysis(client(), detectorId); + deleteRunningDetector(detectorId); + waitUntilTaskDone(client(), detectorId); + List adTasks = searchLatestAdTaskOfDetector(client(), detectorId, taskType); + assertEquals(1, adTasks.size()); + assertEquals(taskId, adTasks.get(0).getTaskId()); + int adResultCount = countADResultOfDetector(client(), detectorId, taskId); + assertTrue(adResultCount > 0); + } + + // This function can only run on new AD version(>=1.1). 
+ // TODO: execute this function on new node in mixed cluster when we have way to send request to specific node. + private void startRealtimeJobForHistoricalDetectorOnNewNode() throws IOException { + for (String detectorId : historicalDetectors) { + String jobId = startAnomalyDetectorDirectly(client(), detectorId); + assertEquals(detectorId, jobId); + Map jobAndTask = getDetectorWithJobAndTask(client(), detectorId); + AnomalyDetectorJob detectorJob = (AnomalyDetectorJob) jobAndTask.get(ANOMALY_DETECTOR_JOB); + assertTrue(detectorJob.isEnabled()); + runningRealtimeDetectors.add(detectorId); + } + } + + private void verifyAllRealtimeJobsRunning() throws IOException { + for (String detectorId : runningRealtimeDetectors) { + Map jobAndTask = getDetectorWithJobAndTask(client(), detectorId); + AnomalyDetectorJob detectorJob = (AnomalyDetectorJob) jobAndTask.get(ANOMALY_DETECTOR_JOB); + assertTrue(detectorJob.isEnabled()); + } + } + + private String getUri() { + switch (CLUSTER_TYPE) { + case OLD: + return "_nodes/" + CLUSTER_NAME + "-0/plugins"; + case MIXED: + String round = System.getProperty("tests.rest.bwcsuite_round"); + if (round.equals("second")) { + return "_nodes/" + CLUSTER_NAME + "-1/plugins"; + } else if (round.equals("third")) { + return "_nodes/" + CLUSTER_NAME + "-2/plugins"; + } else { + return "_nodes/" + CLUSTER_NAME + "-0/plugins"; + } + case UPGRADED: + return "_nodes/plugins"; + default: + throw new AssertionError("unknown cluster type: " + CLUSTER_TYPE); + } + } + + private List createRealtimeAnomalyDetectorsAndStart(ADRestTestUtils.DetectorType detectorType) throws Exception { + switch (detectorType) { + case SINGLE_ENTITY_DETECTOR: + // Create single flow detector + Response singleFlowDetectorResponse = createAnomalyDetector( + client(), + dataIndexName, + MockSimpleLog.TIME_FIELD, + detectionIntervalInMinutes, + windowDelayIntervalInMinutes, + MockSimpleLog.VALUE_FIELD, + aggregationMethod, + null, + null + ); + return startAnomalyDetector(singleFlowDetectorResponse, false); + case SINGLE_CATEGORY_HC_DETECTOR: + // Create single category HC detector + Response singleCategoryHCDetectorResponse = createAnomalyDetector( + client(), + dataIndexName, + MockSimpleLog.TIME_FIELD, + detectionIntervalInMinutes, + windowDelayIntervalInMinutes, + MockSimpleLog.VALUE_FIELD, + aggregationMethod, + null, + ImmutableList.of(MockSimpleLog.CATEGORY_FIELD) + ); + return startAnomalyDetector(singleCategoryHCDetectorResponse, false); + case MULTI_CATEGORY_HC_DETECTOR: + // Create multi-category HC detector + Response multiCategoryHCDetectorResponse = createAnomalyDetector( + client(), + dataIndexName, + MockSimpleLog.TIME_FIELD, + detectionIntervalInMinutes, + windowDelayIntervalInMinutes, + MockSimpleLog.VALUE_FIELD, + aggregationMethod, + null, + ImmutableList.of(MockSimpleLog.IP_FIELD, MockSimpleLog.CATEGORY_FIELD) + ); + return startAnomalyDetector(multiCategoryHCDetectorResponse, false); + default: + return null; + } + } + + private List createHistoricalAnomalyDetectorsAndStart() throws Exception { + // only support single entity for historical detector + Response historicalSingleFlowDetectorResponse = createAnomalyDetector( + client(), + dataIndexName, + MockSimpleLog.TIME_FIELD, + detectionIntervalInMinutes, + windowDelayIntervalInMinutes, + MockSimpleLog.VALUE_FIELD, + aggregationMethod, + null, + null, + true + ); + return startAnomalyDetector(historicalSingleFlowDetectorResponse, true); + } + + private void deleteRunningDetector(String detectorId) { + try { + 
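            // Intent of this helper (inferred from the catch block below): deleting a detector that
            // still has a running realtime job or historical task is expected to be rejected, so only
            // failures whose message mentions "running" are tolerated; any other exception trips the
            // assertTrue in the catch clause and fails the test.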
deleteDetector(client(), detectorId); + } catch (Exception e) { + assertTrue(ExceptionUtil.getErrorMessage(e).contains("running")); + } + } + + private List startAnomalyDetector(Response response, boolean historicalDetector) throws IOException { + // verify that the detector is created + assertEquals("Create anomaly detector failed", RestStatus.CREATED, TestHelpers.restStatus(response)); + Map responseMap = entityAsMap(response); + String detectorId = (String) responseMap.get("_id"); + int version = (int) responseMap.get("_version"); + assertNotEquals("response is missing Id", AnomalyDetector.NO_ID, detectorId); + assertTrue("incorrect version", version > 0); + + Response startDetectorResponse = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI + "/" + detectorId + "/_start", + ImmutableMap.of(), + (HttpEntity) null, + null + ); + Map startDetectorResponseMap = responseAsMap(startDetectorResponse); + String taskOrJobId = (String) startDetectorResponseMap.get("_id"); + assertNotNull(taskOrJobId); + + if (!historicalDetector) { + Map jobAndTask = getDetectorWithJobAndTask(client(), detectorId); + AnomalyDetectorJob job = (AnomalyDetectorJob) jobAndTask.get(ANOMALY_DETECTOR_JOB); + assertTrue(job.isEnabled()); + runningRealtimeDetectors.add(detectorId); + } else { + historicalDetectors.add(detectorId); + } + return ImmutableList.of(detectorId, taskOrJobId); + } + + private void verifyAnomalyDetectorCount(String uri, long expectedCount) throws Exception { + Response response = TestHelpers.makeRequest(client(), "GET", uri + "/" + RestHandlerUtils.COUNT, null, "", null); + Map responseMap = entityAsMap(response); + Integer count = (Integer) responseMap.get("count"); + assertEquals(expectedCount, (long) count); + } + +} diff --git a/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java-e b/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java-e new file mode 100644 index 000000000..5045b45bb --- /dev/null +++ b/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java-e @@ -0,0 +1,142 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.caching; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Random; + +import org.junit.Before; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelManager.ModelType; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.ratelimit.CheckpointMaintainWorker; +import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.model.Entity; + +public class AbstractCacheTest extends AbstractTimeSeriesTest { + protected String modelId1, modelId2, modelId3, modelId4; + protected Entity entity1, entity2, entity3, entity4; + protected ModelState modelState1, modelState2, modelState3, modelState4; + protected String detectorId; + protected AnomalyDetector detector; + protected Clock clock; + protected Duration detectorDuration; + protected float initialPriority; + protected CacheBuffer cacheBuffer; + protected long memoryPerEntity; + protected MemoryTracker memoryTracker; + protected CheckpointWriteWorker checkpointWriteQueue; + protected CheckpointMaintainWorker checkpointMaintainQueue; + protected Random random; + protected int shingleSize; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + detector = mock(AnomalyDetector.class); + detectorId = "123"; + when(detector.getId()).thenReturn(detectorId); + detectorDuration = Duration.ofMinutes(5); + when(detector.getIntervalDuration()).thenReturn(detectorDuration); + when(detector.getIntervalInSeconds()).thenReturn(detectorDuration.getSeconds()); + when(detector.getEnabledFeatureIds()).thenReturn(new ArrayList() { + { + add("a"); + add("b"); + add("c"); + } + }); + shingleSize = 4; + when(detector.getShingleSize()).thenReturn(shingleSize); + + entity1 = Entity.createSingleAttributeEntity("attributeName1", "attributeVal1"); + entity2 = Entity.createSingleAttributeEntity("attributeName1", "attributeVal2"); + entity3 = Entity.createSingleAttributeEntity("attributeName1", "attributeVal3"); + entity4 = Entity.createSingleAttributeEntity("attributeName1", "attributeVal4"); + modelId1 = entity1.getModelId(detectorId).get(); + modelId2 = entity2.getModelId(detectorId).get(); + modelId3 = entity3.getModelId(detectorId).get(); + modelId4 = entity4.getModelId(detectorId).get(); + + clock = mock(Clock.class); + when(clock.instant()).thenReturn(Instant.now()); + + memoryPerEntity = 81920; + memoryTracker = mock(MemoryTracker.class); + + checkpointWriteQueue = mock(CheckpointWriteWorker.class); + checkpointMaintainQueue = mock(CheckpointMaintainWorker.class); + + cacheBuffer = new CacheBuffer( + 1, + 1, + memoryPerEntity, + memoryTracker, + clock, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + detectorId, + checkpointWriteQueue, + checkpointMaintainQueue, + Duration.ofHours(12).toHoursPart() + ); + + initialPriority = cacheBuffer.getPriorityTracker().getUpdatedPriority(0); + + modelState1 = new ModelState<>( + new EntityModel(entity1, new ArrayDeque<>(), null), + modelId1, + detectorId, + ModelType.ENTITY.getName(), + clock, + 0 + ); + + modelState2 = new ModelState<>( + new EntityModel(entity2, new ArrayDeque<>(), null), + modelId2, + 
detectorId, + ModelType.ENTITY.getName(), + clock, + 0 + ); + + modelState3 = new ModelState<>( + new EntityModel(entity3, new ArrayDeque<>(), null), + modelId3, + detectorId, + ModelType.ENTITY.getName(), + clock, + 0 + ); + + modelState4 = new ModelState<>( + new EntityModel(entity4, new ArrayDeque<>(), null), + modelId4, + detectorId, + ModelType.ENTITY.getName(), + clock, + 0 + ); + } +} diff --git a/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java-e b/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java-e new file mode 100644 index 000000000..7332edf4b --- /dev/null +++ b/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java-e @@ -0,0 +1,179 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.caching; + +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Map.Entry; +import java.util.Optional; + +import org.mockito.ArgumentCaptor; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.ratelimit.CheckpointMaintainRequest; + +import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; + +public class CacheBufferTests extends AbstractCacheTest { + + // cache.put(1, 1); + // cache.put(2, 2); + // cache.get(1); // returns 1 + // cache.put(3, 3); // evicts key 2 + // cache.get(2); // returns -1 (not found) + // cache.get(3); // returns 3. + // cache.put(4, 4); // evicts key 1. 
+ // cache.get(1); // returns -1 (not found) + // cache.get(3); // returns 3 + // cache.get(4); // returns 4 + public void testRemovalCandidate() { + cacheBuffer.put(modelId1, modelState1); + cacheBuffer.put(modelId2, modelState2); + assertEquals(modelId1, cacheBuffer.get(modelId1).getModelId()); + Optional> removalCandidate = cacheBuffer.getPriorityTracker().getMinimumScaledPriority(); + assertEquals(modelId2, removalCandidate.get().getKey()); + cacheBuffer.remove(); + cacheBuffer.put(modelId3, modelState3); + assertEquals(null, cacheBuffer.get(modelId2)); + assertEquals(modelId3, cacheBuffer.get(modelId3).getModelId()); + removalCandidate = cacheBuffer.getPriorityTracker().getMinimumScaledPriority(); + assertEquals(modelId1, removalCandidate.get().getKey()); + cacheBuffer.remove(modelId1); + assertEquals(null, cacheBuffer.get(modelId1)); + cacheBuffer.put(modelId4, modelState4); + assertEquals(modelId3, cacheBuffer.get(modelId3).getModelId()); + assertEquals(modelId4, cacheBuffer.get(modelId4).getModelId()); + } + + // cache.put(3, 3); + // cache.put(2, 2); + // cache.put(2, 2); + // cache.put(4, 4); + // cache.get(2) => returns 2 + public void testRemovalCandidate2() throws InterruptedException { + cacheBuffer.put(modelId3, modelState3); + cacheBuffer.put(modelId2, modelState2); + cacheBuffer.put(modelId2, modelState2); + cacheBuffer.put(modelId4, modelState4); + assertTrue(cacheBuffer.getModel(modelId2).isPresent()); + + ArgumentCaptor memoryReleased = ArgumentCaptor.forClass(Long.class); + ArgumentCaptor reserved = ArgumentCaptor.forClass(Boolean.class); + ArgumentCaptor orign = ArgumentCaptor.forClass(MemoryTracker.Origin.class); + cacheBuffer.clear(); + verify(memoryTracker, times(2)).releaseMemory(memoryReleased.capture(), reserved.capture(), orign.capture()); + + List capturedMemoryReleased = memoryReleased.getAllValues(); + List capturedreserved = reserved.getAllValues(); + List capturedOrigin = orign.getAllValues(); + assertEquals(3 * memoryPerEntity, capturedMemoryReleased.stream().reduce(0L, (a, b) -> a + b).intValue()); + assertTrue(capturedreserved.get(0)); + assertTrue(!capturedreserved.get(1)); + assertEquals(MemoryTracker.Origin.HC_DETECTOR, capturedOrigin.get(0)); + + assertTrue(!cacheBuffer.expired(Duration.ofHours(1))); + } + + public void testCanRemove() { + String modelId1 = "1"; + String modelId2 = "2"; + String modelId3 = "3"; + assertTrue(cacheBuffer.dedicatedCacheAvailable()); + assertTrue(!cacheBuffer.canReplaceWithinDetector(100)); + + cacheBuffer.put(modelId1, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); + assertTrue(cacheBuffer.canReplaceWithinDetector(100)); + assertTrue(!cacheBuffer.dedicatedCacheAvailable()); + assertTrue(!cacheBuffer.canRemove()); + cacheBuffer.put(modelId2, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); + assertTrue(cacheBuffer.canRemove()); + cacheBuffer.replace(modelId3, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); + assertTrue(cacheBuffer.isActive(modelId2)); + assertTrue(cacheBuffer.isActive(modelId3)); + assertEquals(modelId3, cacheBuffer.getPriorityTracker().getHighestPriorityEntityId().get()); + assertEquals(2, cacheBuffer.getActiveEntities()); + } + + public void testMaintenance() { + String modelId1 = "1"; + String modelId2 = "2"; + String modelId3 = "3"; + cacheBuffer.put(modelId1, MLUtil.randomModelState(new 
RandomModelStateConfig.Builder().priority(initialPriority).build())); + cacheBuffer.put(modelId2, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); + cacheBuffer.put(modelId3, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); + cacheBuffer.maintenance(); + assertEquals(3, cacheBuffer.getActiveEntities()); + assertEquals(3, cacheBuffer.getAllModels().size()); + // the year of 2122, 100 years later to simulate we are gonna remove all cached entries + when(clock.instant()).thenReturn(Instant.ofEpochSecond(4814540761L)); + cacheBuffer.maintenance(); + assertEquals(0, cacheBuffer.getActiveEntities()); + } + + @SuppressWarnings("unchecked") + public void testMaintainByHourNothingToSave() { + // hash code 49 % 6 = 1 + String modelId1 = "1"; + // hash code 50 % 6 = 2 + String modelId2 = "2"; + // hash code 51 % 6 = 3 + String modelId3 = "3"; + // hour 17. 17 % 6 (check point frequency) = 5 + when(clock.instant()).thenReturn(Instant.ofEpochSecond(1658854904L)); + cacheBuffer.put(modelId1, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); + cacheBuffer.put(modelId2, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); + cacheBuffer.put(modelId3, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); + + ArgumentCaptor> savedStates = ArgumentCaptor.forClass(List.class); + cacheBuffer.maintenance(); + verify(checkpointMaintainQueue, times(1)).putAll(savedStates.capture()); + assertTrue(savedStates.getValue().isEmpty()); + + // hour 13. 13 % 6 (check point frequency) = 1 + when(clock.instant()).thenReturn(Instant.ofEpochSecond(1658928080L)); + + } + + @SuppressWarnings("unchecked") + public void testMaintainByHourSaveOne() { + // hash code 49 % 6 = 1 + String modelId1 = "1"; + // hash code 50 % 6 = 2 + String modelId2 = "2"; + // hash code 51 % 6 = 3 + String modelId3 = "3"; + // hour 13. 13 % 6 (check point frequency) = 1 + when(clock.instant()).thenReturn(Instant.ofEpochSecond(1658928080L)); + cacheBuffer.put(modelId1, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); + cacheBuffer.put(modelId2, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); + cacheBuffer.put(modelId3, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); + + ArgumentCaptor> savedStates = ArgumentCaptor.forClass(List.class); + cacheBuffer.maintenance(); + verify(checkpointMaintainQueue, times(1)).putAll(savedStates.capture()); + List toSave = savedStates.getValue(); + assertEquals(1, toSave.size()); + assertEquals(modelId1, toSave.get(0).getEntityModelId()); + } + + /** + * Test that if we remove a non-existent key, there is no exception + */ + public void testRemovedNull() { + assertEquals(null, cacheBuffer.remove("foo")); + } +} diff --git a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java-e b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java-e new file mode 100644 index 000000000..7774fb314 --- /dev/null +++ b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java-e @@ -0,0 +1,733 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.caching; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.anyDouble; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.ml.CheckpointDao; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ModelManager.ModelType; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.ByteSizeValue; +import org.opensearch.monitor.jvm.JvmInfo; +import org.opensearch.monitor.jvm.JvmInfo.Mem; +import org.opensearch.monitor.jvm.JvmService; +import org.opensearch.threadpool.Scheduler.ScheduledCancellable; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.model.Entity; + +public class PriorityCacheTests extends AbstractCacheTest { + private static final Logger LOG = LogManager.getLogger(PriorityCacheTests.class); + + EntityCache entityCache; + CheckpointDao checkpoint; + ModelManager modelManager; + + ClusterService clusterService; + Settings settings; + String detectorId2; + AnomalyDetector detector2; + double[] point; + int dedicatedCacheSize; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + checkpoint = mock(CheckpointDao.class); + + modelManager = mock(ModelManager.class); + + clusterService = mock(ClusterService.class); + ClusterSettings settings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.DEDICATED_CACHE_SIZE, + AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, + AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, + AnomalyDetectorSettings.CHECKPOINT_TTL, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ + ) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(settings); + + 
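        // A note on the value below (a sketch of the cache layout these tests assume): dedicatedCacheSize
        // is the number of entity-model slots reserved per detector in PriorityCache. With 1, a second
        // model for the same detector can only be hosted when the shared cache has room
        // (memoryTracker.canAllocate(...) stubbed to true) or when it can evict a lower-priority entry,
        // which is what testInActiveCache and testSharedCache below exercise.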
dedicatedCacheSize = 1; + + threadPool = mock(ThreadPool.class); + setUpADThreadPool(threadPool); + + EntityCache cache = new PriorityCache( + checkpoint, + dedicatedCacheSize, + AnomalyDetectorSettings.CHECKPOINT_TTL, + AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES, + memoryTracker, + AnomalyDetectorSettings.NUM_TREES, + clock, + clusterService, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + threadPool, + checkpointWriteQueue, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + checkpointMaintainQueue, + Settings.EMPTY, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ + ); + + CacheProvider cacheProvider = new CacheProvider(); + cacheProvider.set(cache); + entityCache = cacheProvider.get(); + + when(memoryTracker.estimateTRCFModelSize(anyInt(), anyInt(), anyDouble(), anyInt(), anyBoolean())).thenReturn(memoryPerEntity); + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); + + detector2 = mock(AnomalyDetector.class); + detectorId2 = "456"; + when(detector2.getId()).thenReturn(detectorId2); + when(detector2.getIntervalDuration()).thenReturn(detectorDuration); + when(detector2.getIntervalInSeconds()).thenReturn(detectorDuration.getSeconds()); + + point = new double[] { 0.1 }; + } + + public void testCacheHit() { + // 800 MB is the limit + long largeHeapSize = 800_000_000; + JvmInfo info = mock(JvmInfo.class); + Mem mem = mock(Mem.class); + when(info.getMem()).thenReturn(mem); + when(mem.getHeapMax()).thenReturn(new ByteSizeValue(largeHeapSize)); + JvmService jvmService = mock(JvmService.class); + when(jvmService.info()).thenReturn(info); + + // ClusterService clusterService = mock(ClusterService.class); + float modelMaxPercen = 0.1f; + // Settings settings = Settings.builder().put(AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.getKey(), modelMaxPercen).build(); + // ClusterSettings clusterSettings = new ClusterSettings( + // settings, + // Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE))) + // ); + // when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + memoryTracker = spy(new MemoryTracker(jvmService, modelMaxPercen, 0.002, clusterService, mock(ADCircuitBreakerService.class))); + + EntityCache cache = new PriorityCache( + checkpoint, + dedicatedCacheSize, + AnomalyDetectorSettings.CHECKPOINT_TTL, + AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES, + memoryTracker, + AnomalyDetectorSettings.NUM_TREES, + clock, + clusterService, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + threadPool, + checkpointWriteQueue, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + checkpointMaintainQueue, + Settings.EMPTY, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ + ); + + CacheProvider cacheProvider = new CacheProvider(); + cacheProvider.set(cache); + entityCache = cacheProvider.get(); + + // cache miss due to door keeper + assertEquals(null, entityCache.get(modelState1.getModelId(), detector)); + // cache miss due to empty cache + assertEquals(null, entityCache.get(modelState1.getModelId(), detector)); + entityCache.hostIfPossible(detector, modelState1); + assertEquals(1, entityCache.getTotalActiveEntities()); + assertEquals(1, entityCache.getAllModels().size()); + ModelState hitState = entityCache.get(modelState1.getModelId(), detector); + assertEquals(detectorId, hitState.getId()); + EntityModel model = hitState.getModel(); + assertEquals(false, model.getTrcf().isPresent()); + assertTrue(model.getSamples().isEmpty()); + modelState1.getModel().addSample(point); + assertTrue(Arrays.equals(point, 
model.getSamples().peek())); + + ArgumentCaptor memoryConsumed = ArgumentCaptor.forClass(Long.class); + ArgumentCaptor reserved = ArgumentCaptor.forClass(Boolean.class); + ArgumentCaptor origin = ArgumentCaptor.forClass(MemoryTracker.Origin.class); + + // input dimension: 3, shingle: 4 + long expectedMemoryPerEntity = 436828L; + verify(memoryTracker, times(1)).consumeMemory(memoryConsumed.capture(), reserved.capture(), origin.capture()); + assertEquals(dedicatedCacheSize * expectedMemoryPerEntity, memoryConsumed.getValue().intValue()); + assertEquals(true, reserved.getValue().booleanValue()); + assertEquals(MemoryTracker.Origin.HC_DETECTOR, origin.getValue()); + + // for (int i = 0; i < 2; i++) { + // cacheProvider.get(modelId2, detector); + // } + } + + public void testInActiveCache() { + // make modelId1 has enough priority + for (int i = 0; i < 10; i++) { + entityCache.get(modelId1, detector); + } + assertTrue(entityCache.hostIfPossible(detector, modelState1)); + assertEquals(1, entityCache.getActiveEntities(detectorId)); + when(memoryTracker.canAllocate(anyLong())).thenReturn(false); + for (int i = 0; i < 2; i++) { + assertEquals(null, entityCache.get(modelId2, detector)); + } + assertTrue(false == entityCache.hostIfPossible(detector, modelState2)); + // modelId2 gets put to inactive cache due to nothing in shared cache + // and it cannot replace modelId1 + assertEquals(1, entityCache.getActiveEntities(detectorId)); + } + + public void testSharedCache() { + // make modelId1 has enough priority + for (int i = 0; i < 10; i++) { + entityCache.get(modelId1, detector); + } + entityCache.hostIfPossible(detector, modelState1); + assertEquals(1, entityCache.getActiveEntities(detectorId)); + when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + for (int i = 0; i < 2; i++) { + entityCache.get(modelId2, detector); + } + entityCache.hostIfPossible(detector, modelState2); + // modelId2 should be in shared cache + assertEquals(2, entityCache.getActiveEntities(detectorId)); + + for (int i = 0; i < 10; i++) { + entityCache.get(modelId3, detector2); + } + modelState3 = new ModelState<>( + new EntityModel(entity3, new ArrayDeque<>(), null), + modelId3, + detectorId2, + ModelType.ENTITY.getName(), + clock, + 0 + ); + + entityCache.hostIfPossible(detector2, modelState3); + assertEquals(1, entityCache.getActiveEntities(detectorId2)); + when(memoryTracker.canAllocate(anyLong())).thenReturn(false); + for (int i = 0; i < 4; i++) { + // replace modelId2 in shared cache + entityCache.get(modelId4, detector2); + } + modelState4 = new ModelState<>( + new EntityModel(entity4, new ArrayDeque<>(), null), + modelId4, + detectorId2, + ModelType.ENTITY.getName(), + clock, + 0 + ); + entityCache.hostIfPossible(detector2, modelState4); + assertEquals(2, entityCache.getActiveEntities(detectorId2)); + assertEquals(3, entityCache.getTotalActiveEntities()); + assertEquals(3, entityCache.getAllModels().size()); + + when(memoryTracker.memoryToShed()).thenReturn(memoryPerEntity); + entityCache.maintenance(); + assertEquals(2, entityCache.getTotalActiveEntities()); + assertEquals(2, entityCache.getAllModels().size()); + assertEquals(1, entityCache.getActiveEntities(detectorId2)); + } + + public void testReplace() { + for (int i = 0; i < 2; i++) { + entityCache.get(modelState1.getModelId(), detector); + } + + entityCache.hostIfPossible(detector, modelState1); + assertEquals(1, entityCache.getActiveEntities(detectorId)); + when(memoryTracker.canAllocate(anyLong())).thenReturn(false); + ModelState state = null; + + for 
(int i = 0; i < 4; i++) { + entityCache.get(modelId2, detector); + } + + // modelState2 replaces modelState1 in the dedicated cache + entityCache.hostIfPossible(detector, modelState2); + state = entityCache.get(modelId2, detector); + + assertEquals(modelId2, state.getModelId()); + assertEquals(1, entityCache.getActiveEntities(detectorId)); + } + + public void testCannotAllocateBuffer() { + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(false); + expectThrows(LimitExceededException.class, () -> entityCache.get(modelId1, detector)); + } + + public void testExpiredCacheBuffer() { + when(clock.instant()).thenReturn(Instant.MIN); + when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + for (int i = 0; i < 3; i++) { + entityCache.get(modelId1, detector); + } + for (int i = 0; i < 3; i++) { + entityCache.get(modelId2, detector); + } + + entityCache.hostIfPossible(detector, modelState1); + entityCache.hostIfPossible(detector, modelState2); + + assertEquals(2, entityCache.getTotalActiveEntities()); + assertEquals(2, entityCache.getAllModels().size()); + when(clock.instant()).thenReturn(Instant.now()); + entityCache.maintenance(); + assertEquals(0, entityCache.getTotalActiveEntities()); + assertEquals(0, entityCache.getAllModels().size()); + + for (int i = 0; i < 2; i++) { + // doorkeeper should have been reset + assertEquals(null, entityCache.get(modelId2, detector)); + } + } + + public void testClear() { + when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + + for (int i = 0; i < 3; i++) { + // make modelId1 have higher priority + entityCache.get(modelId1, detector); + } + + for (int i = 0; i < 2; i++) { + entityCache.get(modelId2, detector); + } + + entityCache.hostIfPossible(detector, modelState1); + entityCache.hostIfPossible(detector, modelState2); + + assertEquals(2, entityCache.getTotalActiveEntities()); + assertTrue(entityCache.isActive(detectorId, modelId1)); + assertEquals(0, entityCache.getTotalUpdates(detectorId)); + modelState1.getModel().addSample(point); + assertEquals(1, entityCache.getTotalUpdates(detectorId)); + assertEquals(1, entityCache.getTotalUpdates(detectorId, modelId1)); + entityCache.clear(detectorId); + assertEquals(0, entityCache.getTotalActiveEntities()); + + for (int i = 0; i < 2; i++) { + // doorkeeper should have been reset + assertEquals(null, entityCache.get(modelId2, detector)); + } + } + + class CleanRunnable implements Runnable { + @Override + public void run() { + entityCache.maintenance(); + } + } + + private void setUpConcurrentMaintenance() { + when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + for (int i = 0; i < 2; i++) { + entityCache.get(modelId1, detector); + } + for (int i = 0; i < 2; i++) { + entityCache.get(modelId2, detector); + } + for (int i = 0; i < 2; i++) { + entityCache.get(modelId3, detector); + } + + entityCache.hostIfPossible(detector, modelState1); + entityCache.hostIfPossible(detector, modelState2); + entityCache.hostIfPossible(detector, modelState3); + + when(memoryTracker.memoryToShed()).thenReturn(memoryPerEntity); + assertEquals(3, entityCache.getTotalActiveEntities()); + } + + public void testSuccessfulConcurrentMaintenance() { + setUpConcurrentMaintenance(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + doAnswer(invocation -> { + inProgressLatch.await(100, TimeUnit.SECONDS); + return null; + }).when(memoryTracker).releaseMemory(anyLong(), anyBoolean(), any(MemoryTracker.Origin.class)); + + doAnswer(invocation -> { + inProgressLatch.countDown(); + return mock(ScheduledCancellable.class); 
+ }).when(threadPool).schedule(any(), any(), any()); + + // both maintenance calls will be blocked until schedule gets called + new Thread(new CleanRunnable()).start(); + + entityCache.maintenance(); + + verify(threadPool, times(1)).schedule(any(), any(), any()); + } + + class FailedCleanRunnable implements Runnable { + CountDownLatch signalThreadToStart; + + FailedCleanRunnable(CountDownLatch countDown) { + this.signalThreadToStart = countDown; + } + + @Override + public void run() { + try { + entityCache.maintenance(); + } catch (Exception e) { + // maintenance can throw AnomalyDetectionException, catch it here + signalThreadToStart.countDown(); + } + } + } + + public void testFailedConcurrentMaintenance() throws InterruptedException { + setUpConcurrentMaintenance(); + final CountDownLatch scheduleCountDown = new CountDownLatch(1); + final CountDownLatch scheduledThreadCountDown = new CountDownLatch(1); + + doThrow(NullPointerException.class).when(memoryTracker).releaseMemory(anyLong(), anyBoolean(), any(MemoryTracker.Origin.class)); + + doAnswer(invocation -> { + scheduleCountDown.await(100, TimeUnit.SECONDS); + return null; + }).when(memoryTracker).syncMemoryState(any(MemoryTracker.Origin.class), anyLong(), anyLong()); + + AtomicReference<Runnable> runnable = new AtomicReference<Runnable>(); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + runnable.set((Runnable) args[0]); + scheduleCountDown.countDown(); + return mock(ScheduledCancellable.class); + }).when(threadPool).schedule(any(), any(), any()); + + try { + // both maintenance calls will be blocked until schedule gets called + new Thread(new FailedCleanRunnable(scheduledThreadCountDown)).start(); + + entityCache.maintenance(); + } catch (TimeSeriesException e) { + scheduledThreadCountDown.countDown(); + } + + scheduledThreadCountDown.await(100, TimeUnit.SECONDS); + + // first thread finishes and throws an exception + assertTrue(runnable.get() != null); + try { + // invoke second thread's runnable object + runnable.get().run(); + } catch (Exception e2) { + // runnable will log a line and return. It won't cause any exception. 
+ assertTrue(false); + return; + } + // we should return here + return; + } + + private void selectTestCommon(int entityFreq) { + for (int i = 0; i < entityFreq; i++) { + // bypass doorkeeper + entityCache.get(entity1.getModelId(detectorId).get(), detector); + } + Collection cacheMissEntities = new ArrayList<>(); + cacheMissEntities.add(entity1); + Pair, List> selectedAndOther = entityCache.selectUpdateCandidate(cacheMissEntities, detectorId, detector); + List selected = selectedAndOther.getLeft(); + assertEquals(1, selected.size()); + assertEquals(entity1, selected.get(0)); + assertEquals(0, selectedAndOther.getRight().size()); + } + + public void testSelectToDedicatedCache() { + selectTestCommon(2); + } + + public void testSelectToSharedCache() { + for (int i = 0; i < 2; i++) { + // bypass doorkeeper + entityCache.get(entity2.getModelId(detectorId).get(), detector); + } + when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + + // fill in dedicated cache + entityCache.hostIfPossible(detector, modelState2); + selectTestCommon(2); + verify(memoryTracker, times(1)).canAllocate(anyLong()); + } + + public void testSelectToReplaceInCache() { + for (int i = 0; i < 2; i++) { + // bypass doorkeeper + entityCache.get(entity2.getModelId(detectorId).get(), detector); + } + when(memoryTracker.canAllocate(anyLong())).thenReturn(false); + + // fill in dedicated cache + entityCache.hostIfPossible(detector, modelState2); + // make entity1 have enough priority to replace entity2 + selectTestCommon(10); + verify(memoryTracker, times(1)).canAllocate(anyLong()); + } + + private void replaceInOtherCacheSetUp() { + Entity entity5 = Entity.createSingleAttributeEntity("attributeName1", "attributeVal5"); + Entity entity6 = Entity.createSingleAttributeEntity("attributeName1", "attributeVal6"); + ModelState modelState5 = new ModelState<>( + new EntityModel(entity5, new ArrayDeque<>(), null), + entity5.getModelId(detectorId2).get(), + detectorId2, + ModelType.ENTITY.getName(), + clock, + 0 + ); + ModelState modelState6 = new ModelState<>( + new EntityModel(entity6, new ArrayDeque<>(), null), + entity6.getModelId(detectorId2).get(), + detectorId2, + ModelType.ENTITY.getName(), + clock, + 0 + ); + + for (int i = 0; i < 3; i++) { + // bypass doorkeeper and leave room for lower frequency entity in testSelectToCold + entityCache.get(entity5.getModelId(detectorId2).get(), detector2); + entityCache.get(entity6.getModelId(detectorId2).get(), detector2); + } + for (int i = 0; i < 10; i++) { + // entity1 cannot replace entity2 due to frequency + entityCache.get(entity2.getModelId(detectorId).get(), detector); + } + // put modelState5 in dedicated and modelState6 in shared cache + when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + entityCache.hostIfPossible(detector2, modelState5); + entityCache.hostIfPossible(detector2, modelState6); + + // fill in dedicated cache + entityCache.hostIfPossible(detector, modelState2); + + // don't allow to use shared cache afterwards + when(memoryTracker.canAllocate(anyLong())).thenReturn(false); + } + + public void testSelectToReplaceInOtherCache() { + replaceInOtherCacheSetUp(); + + // make entity1 have enough priority to replace entity2 + selectTestCommon(10); + // once when deciding whether to host modelState6; + // once when calling selectUpdateCandidate on entity1 + verify(memoryTracker, times(2)).canAllocate(anyLong()); + } + + public void testSelectToCold() { + replaceInOtherCacheSetUp(); + + for (int i = 0; i < 2; i++) { + // bypass doorkeeper + 
entityCache.get(entity1.getModelId(detectorId).get(), detector); + } + Collection cacheMissEntities = new ArrayList<>(); + cacheMissEntities.add(entity1); + Pair, List> selectedAndOther = entityCache.selectUpdateCandidate(cacheMissEntities, detectorId, detector); + List cold = selectedAndOther.getRight(); + assertEquals(1, cold.size()); + assertEquals(entity1, cold.get(0)); + assertEquals(0, selectedAndOther.getLeft().size()); + } + + /* + * Test the scenario: + * 1. A detector's buffer uses dedicated and shared memory + * 2. a new detector's buffer is created and triggers clearMemory (every new + * CacheBuffer creation will trigger it) + * 3. clearMemory found we can reclaim shared memory + */ + public void testClearMemory() { + for (int i = 0; i < 2; i++) { + // bypass doorkeeper + entityCache.get(entity2.getModelId(detectorId).get(), detector); + } + + for (int i = 0; i < 10; i++) { + // bypass doorkeeper and make entity1 have higher frequency + entityCache.get(entity1.getModelId(detectorId).get(), detector); + } + + // put modelState5 in dedicated and modelState6 in shared cache + when(memoryTracker.canAllocate(anyLong())).thenReturn(true); + entityCache.hostIfPossible(detector, modelState1); + entityCache.hostIfPossible(detector, modelState2); + + // two entities get inserted to cache + assertTrue(null != entityCache.get(entity1.getModelId(detectorId).get(), detector)); + assertTrue(null != entityCache.get(entity2.getModelId(detectorId).get(), detector)); + + Entity entity5 = Entity.createSingleAttributeEntity("attributeName1", "attributeVal5"); + when(memoryTracker.memoryToShed()).thenReturn(memoryPerEntity); + for (int i = 0; i < 2; i++) { + // bypass doorkeeper, CacheBuffer created, and trigger clearMemory + entityCache.get(entity5.getModelId(detectorId2).get(), detector2); + } + + assertTrue(null != entityCache.get(entity1.getModelId(detectorId).get(), detector)); + // entity 2 removed + assertTrue(null == entityCache.get(entity2.getModelId(detectorId).get(), detector)); + assertTrue(null == entityCache.get(entity5.getModelId(detectorId2).get(), detector)); + } + + public void testSelectEmpty() { + Collection cacheMissEntities = new ArrayList<>(); + cacheMissEntities.add(entity1); + Pair, List> selectedAndOther = entityCache.selectUpdateCandidate(cacheMissEntities, detectorId, detector); + assertEquals(0, selectedAndOther.getLeft().size()); + assertEquals(0, selectedAndOther.getRight().size()); + } + + // test that detector interval is more than 1 hour that maintenance is called before + // the next get method + public void testLongDetectorInterval() { + try { + ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.DOOR_KEEPER_IN_CACHE_ENABLED, true); + when(clock.instant()).thenReturn(Instant.ofEpochSecond(1000)); + when(detector.getIntervalDuration()).thenReturn(Duration.ofHours(12)); + String modelId = entity1.getModelId(detectorId).get(); + // record last access time 1000 + assertTrue(null == entityCache.get(modelId, detector)); + assertEquals(-1, entityCache.getLastActiveMs(detectorId, modelId)); + // 2 hour = 7200 seconds have passed + long currentTimeEpoch = 8200; + when(clock.instant()).thenReturn(Instant.ofEpochSecond(currentTimeEpoch)); + // door keeper should not be expired since we reclaim space every 60 intervals + entityCache.maintenance(); + // door keeper still has the record and won't blocks entity state being created + entityCache.get(modelId, detector); + // * 1000 to convert to milliseconds + assertEquals(currentTimeEpoch * 1000, 
entityCache.getLastActiveMs(detectorId, modelId)); + } finally { + ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.DOOR_KEEPER_IN_CACHE_ENABLED, false); + } + } + + public void testGetNoPriorityUpdate() { + for (int i = 0; i < 3; i++) { + // bypass doorkeeper + entityCache.get(entity2.getModelId(detectorId).get(), detector); + } + + // fill in dedicated cache + entityCache.hostIfPossible(detector, modelState2); + + // don't allow to use shared cache afterwards + when(memoryTracker.canAllocate(anyLong())).thenReturn(false); + + for (int i = 0; i < 2; i++) { + // bypass doorkeeper + entityCache.get(entity1.getModelId(detectorId).get(), detector); + } + for (int i = 0; i < 10; i++) { + // won't increase frequency + entityCache.getForMaintainance(detectorId, entity1.getModelId(detectorId).get()); + } + + entityCache.hostIfPossible(detector, modelState1); + + // entity1 does not replace entity2 + assertTrue(null == entityCache.get(entity1.getModelId(detectorId).get(), detector)); + assertTrue(null != entityCache.get(entity2.getModelId(detectorId).get(), detector)); + + for (int i = 0; i < 10; i++) { + // increase frequency + entityCache.get(entity1.getModelId(detectorId).get(), detector); + } + + entityCache.hostIfPossible(detector, modelState1); + + // entity1 replace entity2 + assertTrue(null != entityCache.get(entity1.getModelId(detectorId).get(), detector)); + assertTrue(null == entityCache.get(entity2.getModelId(detectorId).get(), detector)); + } + + public void testRemoveEntityModel() { + for (int i = 0; i < 3; i++) { + // bypass doorkeeper + entityCache.get(entity2.getModelId(detectorId).get(), detector); + } + + // fill in dedicated cache + entityCache.hostIfPossible(detector, modelState2); + + assertTrue(null != entityCache.get(entity2.getModelId(detectorId).get(), detector)); + + entityCache.removeEntityModel(detectorId, entity2.getModelId(detectorId).get()); + + assertTrue(null == entityCache.get(entity2.getModelId(detectorId).get(), detector)); + + verify(checkpoint, times(1)).deleteModelCheckpoint(eq(entity2.getModelId(detectorId).get()), any()); + verify(checkpointWriteQueue, never()).write(any(), anyBoolean(), any()); + } +} diff --git a/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java-e b/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java-e new file mode 100644 index 000000000..4e721d68e --- /dev/null +++ b/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java-e @@ -0,0 +1,100 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.caching; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.time.Clock; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.List; + +import org.junit.Before; +import org.opensearch.test.OpenSearchTestCase; + +public class PriorityTrackerTests extends OpenSearchTestCase { + Clock clock; + PriorityTracker tracker; + Instant now; + String entity1, entity2, entity3; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + clock = mock(Clock.class); + now = Instant.now(); + tracker = new PriorityTracker(clock, 1, now.getEpochSecond(), 3); + entity1 = "entity1"; + entity2 = "entity2"; + entity3 = "entity3"; + } + + public void testNormal() { + when(clock.instant()).thenReturn(now); + // first interval entity 1 and 3 + tracker.updatePriority(entity1); + tracker.updatePriority(entity3); + when(clock.instant()).thenReturn(now.plusSeconds(60L)); + // second interval entity 1 and 2 + tracker.updatePriority(entity1); + tracker.updatePriority(entity2); + // we should have entity 1, 2, 3 in order. 2 comes before 3 because it was updated more recently + List<String> top3 = tracker.getTopNEntities(3); + assertEquals(entity1, top3.get(0)); + assertEquals(entity2, top3.get(1)); + assertEquals(entity3, top3.get(2)); + + // even though we ask for the top 4, only 3 entities exist + List<String> top4 = tracker.getTopNEntities(4); + assertEquals(3, top4.size()); + assertEquals(entity1, top3.get(0)); + assertEquals(entity2, top3.get(1)); + assertEquals(entity3, top3.get(2)); + } + + public void testOverflow() { + when(clock.instant()).thenReturn(now); + tracker.updatePriority(entity1); + float priority1 = tracker.getMinimumScaledPriority().get().getValue(); + + // when(clock.instant()).thenReturn(now.plusSeconds(60L)); + tracker.updatePriority(entity1); + float priority2 = tracker.getMinimumScaledPriority().get().getValue(); + // we incremented the priority + assertTrue("The following is expected: " + priority2 + " > " + priority1, priority2 > priority1); + + when(clock.instant()).thenReturn(now.plus(3, ChronoUnit.DAYS)); + tracker.updatePriority(entity1); + // overflow happens, we use increment as the new priority + assertEquals(0, tracker.getMinimumScaledPriority().get().getValue().floatValue(), 0.001); + } + + public void testTooManyEntities() { + when(clock.instant()).thenReturn(now); + tracker = new PriorityTracker(clock, 1, now.getEpochSecond(), 2); + tracker.updatePriority(entity1); + tracker.updatePriority(entity3); + assertEquals(2, tracker.size()); + tracker.updatePriority(entity2); + // one entity is kicked out because the size limit is reached. 
+ assertEquals(2, tracker.size()); + } + + public void testEmptyTracker() { + assertTrue(!tracker.getMinimumScaledPriority().isPresent()); + assertTrue(!tracker.getMinimumPriority().isPresent()); + assertTrue(!tracker.getMinimumPriorityEntityId().isPresent()); + assertTrue(!tracker.getHighestPriorityEntityId().isPresent()); + } +} diff --git a/src/test/java/org/opensearch/ad/cluster/ADClusterEventListenerTests.java-e b/src/test/java/org/opensearch/ad/cluster/ADClusterEventListenerTests.java-e new file mode 100644 index 000000000..415ec75fe --- /dev/null +++ b/src/test/java/org/opensearch/ad/cluster/ADClusterEventListenerTests.java-e @@ -0,0 +1,209 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.cluster; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.opensearch.cluster.node.DiscoveryNodeRole.BUILT_IN_ROLES; +import static org.opensearch.test.ClusterServiceUtils.createClusterService; + +import java.util.HashMap; +import java.util.Optional; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.cluster.ClusterChangedEvent; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlocks; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.gateway.GatewayService; +import org.opensearch.timeseries.AbstractTimeSeriesTest; + +public class ADClusterEventListenerTests extends AbstractTimeSeriesTest { + private final String clusterManagerNodeId = "clusterManagerNode"; + private final String dataNode1Id = "dataNode1"; + private final String clusterName = "multi-node-cluster"; + + private ClusterService clusterService; + private ADClusterEventListener listener; + private HashRing hashRing; + private ClusterState oldClusterState; + private ClusterState newClusterState; + private DiscoveryNode clusterManagerNode; + private DiscoveryNode dataNode1; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(ADClusterEventListenerTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + super.setUpLog4jForJUnit(ADClusterEventListener.class); + clusterService = createClusterService(threadPool); + hashRing = mock(HashRing.class); + + clusterManagerNode = new DiscoveryNode( + clusterManagerNodeId, + buildNewFakeTransportAddress(), + emptyMap(), + emptySet(), + Version.CURRENT + ); + dataNode1 = new DiscoveryNode(dataNode1Id, buildNewFakeTransportAddress(), emptyMap(), BUILT_IN_ROLES, Version.CURRENT); + oldClusterState = ClusterState + .builder(new ClusterName(clusterName)) + .nodes( + new DiscoveryNodes.Builder() + .clusterManagerNodeId(clusterManagerNodeId) + 
.localNodeId(clusterManagerNodeId) + .add(clusterManagerNode) + ) + .build(); + newClusterState = ClusterState + .builder(new ClusterName(clusterName)) + .nodes( + new DiscoveryNodes.Builder() + .clusterManagerNodeId(clusterManagerNodeId) + .localNodeId(dataNode1Id) + .add(clusterManagerNode) + .add(dataNode1) + ) + .build(); + + listener = new ADClusterEventListener(clusterService, hashRing); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + super.tearDownLog4jForJUnit(); + clusterService = null; + hashRing = null; + oldClusterState = null; + listener = null; + } + + public void testUnchangedClusterState() { + listener.clusterChanged(new ClusterChangedEvent("foo", oldClusterState, oldClusterState)); + assertTrue(!testAppender.containsMessage(ADClusterEventListener.NODE_CHANGED_MSG)); + } + + public void testIsWarmNode() { + HashMap attributesForNode1 = new HashMap<>(); + attributesForNode1.put(ADCommonName.BOX_TYPE_KEY, ADCommonName.WARM_BOX_TYPE); + dataNode1 = new DiscoveryNode(dataNode1Id, buildNewFakeTransportAddress(), attributesForNode1, BUILT_IN_ROLES, Version.CURRENT); + + ClusterState warmNodeClusterState = ClusterState + .builder(new ClusterName(clusterName)) + .nodes( + new DiscoveryNodes.Builder() + .clusterManagerNodeId(clusterManagerNodeId) + .localNodeId(dataNode1Id) + .add(clusterManagerNode) + .add(dataNode1) + ) + .blocks(ClusterBlocks.builder().addGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) + .build(); + listener.clusterChanged(new ClusterChangedEvent("foo", warmNodeClusterState, oldClusterState)); + assertTrue(testAppender.containsMessage(ADClusterEventListener.NOT_RECOVERED_MSG)); + } + + public void testNotRecovered() { + ClusterState blockedClusterState = ClusterState + .builder(new ClusterName(clusterName)) + .nodes( + new DiscoveryNodes.Builder() + .clusterManagerNodeId(clusterManagerNodeId) + .localNodeId(dataNode1Id) + .add(clusterManagerNode) + .add(dataNode1) + ) + .blocks(ClusterBlocks.builder().addGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) + .build(); + listener.clusterChanged(new ClusterChangedEvent("foo", blockedClusterState, oldClusterState)); + assertTrue(testAppender.containsMessage(ADClusterEventListener.NOT_RECOVERED_MSG)); + } + + class ListenerRunnable implements Runnable { + + @Override + public void run() { + listener.clusterChanged(new ClusterChangedEvent("foo", newClusterState, oldClusterState)); + } + } + + public void testInProgress() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + Thread.sleep(1000); + listener.onResponse(true); + return null; + }).when(hashRing).buildCircles(any(), any()); + new Thread(new ListenerRunnable()).start(); + listener.clusterChanged(new ClusterChangedEvent("bar", newClusterState, oldClusterState)); + assertTrue(testAppender.containsMessage(ADClusterEventListener.IN_PROGRESS_MSG)); + } + + public void testNodeAdded() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(true); + return null; + }).when(hashRing).buildCircles(any(), any()); + + doAnswer(invocation -> Optional.of(clusterManagerNode)) + .when(hashRing) + .getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)); + + listener.clusterChanged(new ClusterChangedEvent("foo", newClusterState, oldClusterState)); + assertTrue(testAppender.containsMessage(ADClusterEventListener.NODE_CHANGED_MSG)); + assertTrue(testAppender.containsMessage("node removed: false, node added: true")); + } + + public 
void testNodeRemoved() { + ClusterState twoDataNodeClusterState = ClusterState + .builder(new ClusterName(clusterName)) + .nodes( + new DiscoveryNodes.Builder() + .clusterManagerNodeId(clusterManagerNodeId) + .localNodeId(dataNode1Id) + .add(new DiscoveryNode(clusterManagerNodeId, buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT)) + .add(dataNode1) + .add(new DiscoveryNode("dataNode2", buildNewFakeTransportAddress(), emptyMap(), BUILT_IN_ROLES, Version.CURRENT)) + ) + .build(); + + listener.clusterChanged(new ClusterChangedEvent("foo", newClusterState, twoDataNodeClusterState)); + assertTrue(testAppender.containsMessage(ADClusterEventListener.NODE_CHANGED_MSG)); + assertTrue(testAppender.containsMessage("node removed: true, node added: true")); + } +} diff --git a/src/test/java/org/opensearch/ad/cluster/ADDataMigratorTests.java b/src/test/java/org/opensearch/ad/cluster/ADDataMigratorTests.java index 1b2bcde4a..65dcd37fa 100644 --- a/src/test/java/org/opensearch/ad/cluster/ADDataMigratorTests.java +++ b/src/test/java/org/opensearch/ad/cluster/ADDataMigratorTests.java @@ -38,10 +38,10 @@ import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; -import org.opensearch.index.shard.ShardId; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; diff --git a/src/test/java/org/opensearch/ad/cluster/ADDataMigratorTests.java-e b/src/test/java/org/opensearch/ad/cluster/ADDataMigratorTests.java-e new file mode 100644 index 000000000..4ebdeb257 --- /dev/null +++ b/src/test/java/org/opensearch/ad/cluster/ADDataMigratorTests.java-e @@ -0,0 +1,469 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.cluster; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.NoShardAvailableActionException; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.ad.ADUnitTestCase; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonName; + +public class ADDataMigratorTests extends ADUnitTestCase { + private Client client; + private ClusterService clusterService; + private NamedXContentRegistry namedXContentRegistry; + private ADIndexManagement detectionIndices; + private ADDataMigrator adDataMigrator; + private String detectorId; + private String taskId; + private String detectorContent; + private String jobContent; + private String indexResponseContent; + private String internalError; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + client = mock(Client.class); + clusterService = mock(ClusterService.class); + namedXContentRegistry = TestHelpers.xContentRegistry(); + detectionIndices = mock(ADIndexManagement.class); + detectorId = randomAlphaOfLength(10); + taskId = randomAlphaOfLength(10); + detectorContent = "{\"_index\":\".opendistro-anomaly-detectors\",\"_type\":\"_doc\",\"_id\":\"" + + detectorId + + "\",\"_version\":1,\"_seq_no\":1,\"_primary_term\":51,\"found\":true,\"_source\":{\"name\":\"old_r3\"," + + "\"description\":\"nab_ec2_cpu_utilization_24ae8d\",\"time_field\":\"timestamp\",\"indices\":" + + "[\"nab_ec2_cpu_utilization_24ae8d\"],\"filter_query\":{\"match_all\":{\"boost\":1}}," + + "\"detection_interval\":{\"period\":{\"interval\":1,\"unit\":\"Minutes\"}},\"window_delay\":" + + "{\"period\":{\"interval\":1,\"unit\":\"Minutes\"}},\"shingle_size\":8,\"schema_version\":0," + + "\"feature_attributes\":[{\"feature_id\":\"-nTqeXsBxGq4rqj0VvQy\",\"feature_name\":\"F1\"," + + "\"feature_enabled\":true,\"aggregation_query\":{\"f_1\":{\"sum\":{\"field\":\"value\"}}}}]," + + "\"last_update_time\":1629838005821,\"detector_type\":\"REALTIME_SINGLE_ENTITY\"}}"; + 
jobContent = "{\"_index\":\".opendistro-anomaly-detector-jobs\",\"_type\":\"_doc\",\"_id\":\"" + + detectorId + + "\",\"_score\":1,\"_source\":{\"name\":\"" + + detectorId + + "\",\"schedule\":{\"interval\":{\"start_time\":1629838017881,\"period\":1,\"unit\":\"Minutes\"}}," + + "\"window_delay\":{\"period\":{\"interval\":1,\"unit\":\"Minutes\"}},\"enabled\":true," + + "\"enabled_time\":1629838017881,\"last_update_time\":1629841634355,\"lock_duration_seconds\":60," + + "\"disabled_time\":1629841634355}}"; + indexResponseContent = "{\"_index\":\".opendistro-anomaly-detection-state\",\"_type\":\"_doc\",\"_id\":\"" + + taskId + + "\",\"_version\":1,\"result\":\"created\",\"_shards\":{\"total\":2,\"successful\":2,\"failed\":0}," + + "\"_seq_no\":0,\"_primary_term\":1}"; + internalError = "{\"_index\":\".opendistro-anomaly-detection-state\",\"_type\":\"_doc\",\"_id\":" + + "\"" + + detectorId + + "\",\"_version\":1,\"_seq_no\":10,\"_primary_term\":2,\"found\":true," + + "\"_source\":{\"last_update_time\":1629860362885,\"error\":\"test error\"}}"; + + adDataMigrator = spy(new ADDataMigrator(client, clusterService, namedXContentRegistry, detectionIndices)); + } + + public void testMigrateDataWithNullJobResponse() { + when(detectionIndices.doesJobIndexExist()).thenReturn(true); + when(detectionIndices.doesStateIndexExist()).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).search(any(), any()); + + adDataMigrator.migrateData(); + verify(adDataMigrator, never()).backfillRealtimeTask(any(), anyBoolean()); + } + + public void testMigrateDataWithInitingDetectionStateIndexFailure() { + when(detectionIndices.doesJobIndexExist()).thenReturn(true); + when(detectionIndices.doesStateIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException("test")); + return null; + }).when(detectionIndices).initStateIndex(any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).search(any(), any()); + + adDataMigrator.migrateData(); + verify(adDataMigrator, never()).migrateDetectorInternalStateToRealtimeTask(); + } + + public void testMigrateDataWithInitingDetectionStateIndexAlreadyExists() { + when(detectionIndices.doesJobIndexExist()).thenReturn(true); + when(detectionIndices.doesStateIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new ResourceAlreadyExistsException("test")); + return null; + }).when(detectionIndices).initStateIndex(any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).search(any(), any()); + + adDataMigrator.migrateData(); + verify(adDataMigrator, times(1)).migrateDetectorInternalStateToRealtimeTask(); + } + + public void testMigrateDataWithInitingDetectionStateIndexNotAcknowledged() { + when(detectionIndices.doesJobIndexExist()).thenReturn(true); + when(detectionIndices.doesStateIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(new CreateIndexResponse(false, false, DETECTION_STATE_INDEX)); + return null; + }).when(detectionIndices).initStateIndex(any()); + + doAnswer(invocation -> { + ActionListener listener = 
invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).search(any(), any()); + + adDataMigrator.migrateData(); + verify(adDataMigrator, never()).migrateDetectorInternalStateToRealtimeTask(); + } + + public void testMigrateDataWithInitingDetectionStateIndexAcknowledged() { + when(detectionIndices.doesJobIndexExist()).thenReturn(true); + when(detectionIndices.doesStateIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(new CreateIndexResponse(true, false, DETECTION_STATE_INDEX)); + return null; + }).when(detectionIndices).initStateIndex(any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).search(any(), any()); + + adDataMigrator.migrateData(); + verify(adDataMigrator, times(1)).migrateDetectorInternalStateToRealtimeTask(); + } + + public void testMigrateDataWithEmptyJobResponse() { + when(detectionIndices.doesJobIndexExist()).thenReturn(true); + when(detectionIndices.doesStateIndexExist()).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse response = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + response, + null, + 1, + 1, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + adDataMigrator.migrateData(); + verify(adDataMigrator, never()).backfillRealtimeTask(any(), anyBoolean()); + } + + public void testMigrateDataWithNormalJobResponseButMissingDetector() { + when(detectionIndices.doesJobIndexExist()).thenReturn(true); + when(detectionIndices.doesStateIndexExist()).thenReturn(true); + + doAnswer(invocation -> { + // Return correct AD job when search job index + ActionListener listener = invocation.getArgument(1); + String detectorId = randomAlphaOfLength(10); + SearchHit job = SearchHit.fromXContent(TestHelpers.parser(jobContent)); + SearchHits searchHits = new SearchHits(new SearchHit[] { job }, new TotalHits(2, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse response = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + response, + null, + 1, + 1, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + listener.onResponse(searchResponse); + return null; + }).doAnswer(invocation -> { + // Return null when search realtime tasks + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).search(any(), any()); + + // Return null when get detector and internal error from index. 
+ doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(), any()); + + adDataMigrator.migrateData(); + verify(adDataMigrator, times(2)).backfillRealtimeTask(any(), anyBoolean()); + verify(client, never()).index(any(), any()); + } + + public void testMigrateDataWithNormalJobResponseAndExistingDetector() { + when(detectionIndices.doesJobIndexExist()).thenReturn(true); + when(detectionIndices.doesStateIndexExist()).thenReturn(true); + + String detectorId = randomAlphaOfLength(10); + doAnswer(invocation -> { + // Return correct AD job when search job index + ActionListener listener = invocation.getArgument(1); + SearchHit job1 = SearchHit.fromXContent(TestHelpers.parser(jobContent)); + SearchHits searchHits = new SearchHits(new SearchHit[] { job1 }, new TotalHits(2, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse response = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + response, + null, + 1, + 1, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + listener.onResponse(searchResponse); + return null; + }).doAnswer(invocation -> { + // Return null when search realtime tasks + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).search(any(), any()); + + doAnswer(invocation -> { + // Return null when get detector internal error from index. + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).doAnswer(invocation -> { + // Return correct detector when get detector index. + ActionListener listener = invocation.getArgument(1); + XContentParser parser = TestHelpers.parser(detectorContent, false); + GetResponse getResponse = GetResponse.fromXContent(parser); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + String taskId = randomAlphaOfLength(5); + IndexResponse indexResponse = IndexResponse.fromXContent(TestHelpers.parser(indexResponseContent, false)); + listener.onResponse(indexResponse); + return null; + }).when(client).index(any(), any()); + DiscoveryNode localNode = createNode("localNodeId"); + doReturn(localNode).when(clusterService).localNode(); + + adDataMigrator.migrateData(); + verify(adDataMigrator, times(2)).backfillRealtimeTask(any(), anyBoolean()); + verify(client, times(1)).index(any(), any()); + } + + public void testMigrateDataWithNormalJobResponse_ExistingDetector_ExistingInternalError() { + when(detectionIndices.doesJobIndexExist()).thenReturn(true); + when(detectionIndices.doesStateIndexExist()).thenReturn(true); + + String detectorId = randomAlphaOfLength(10); + doAnswer(invocation -> { + // Return correct AD job when search job index + ActionListener listener = invocation.getArgument(1); + SearchHit job1 = SearchHit.fromXContent(TestHelpers.parser(jobContent)); + SearchHits searchHits = new SearchHits(new SearchHit[] { job1 }, new TotalHits(2, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse response = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + response, + null, + 1, + 1, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + 
SearchResponse.Clusters.EMPTY + ); + listener.onResponse(searchResponse); + return null; + }).doAnswer(invocation -> { + // Return null when search realtime tasks + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).search(any(), any()); + + doAnswer(invocation -> { + // Return null when get detector internal error from index. + ActionListener listener = invocation.getArgument(1); + XContentParser parser = TestHelpers.parser(internalError, false); + GetResponse getResponse = GetResponse.fromXContent(parser); + listener.onResponse(getResponse); + return null; + }).doAnswer(invocation -> { + // Return correct detector when get detector index. + ActionListener listener = invocation.getArgument(1); + XContentParser parser = TestHelpers.parser(detectorContent, false); + GetResponse getResponse = GetResponse.fromXContent(parser); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + String taskId = randomAlphaOfLength(5); + IndexResponse indexResponse = IndexResponse.fromXContent(TestHelpers.parser(indexResponseContent, false)); + listener.onResponse(indexResponse); + return null; + }).when(client).index(any(), any()); + DiscoveryNode localNode = createNode("localNodeId"); + doReturn(localNode).when(clusterService).localNode(); + + adDataMigrator.migrateData(); + verify(adDataMigrator, times(2)).backfillRealtimeTask(any(), anyBoolean()); + verify(client, times(1)).index(any(), any()); + } + + public void testMigrateDataTwice() { + adDataMigrator.migrateData(); + adDataMigrator.migrateData(); + verify(detectionIndices, times(1)).doesJobIndexExist(); + } + + public void testMigrateDataWithNoAvailableShardsException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener + .onFailure( + new NoShardAvailableActionException(ShardId.fromString("[.opendistro-anomaly-detector-jobs][1]"), "all shards failed") + ); + return null; + }).when(client).search(any(), any()); + when(detectionIndices.doesJobIndexExist()).thenReturn(true); + when(detectionIndices.doesStateIndexExist()).thenReturn(true); + + adDataMigrator.migrateData(); + assertFalse(adDataMigrator.isMigrated()); + } + + public void testMigrateDataWithIndexNotFoundException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new IndexNotFoundException(CommonName.JOB_INDEX)); + return null; + }).when(client).search(any(), any()); + when(detectionIndices.doesJobIndexExist()).thenReturn(true); + when(detectionIndices.doesStateIndexExist()).thenReturn(true); + + adDataMigrator.migrateData(); + verify(adDataMigrator, never()).backfillRealtimeTask(any(), anyBoolean()); + assertTrue(adDataMigrator.isMigrated()); + } + + public void testMigrateDataWithUnknownException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("test unknown exception")); + return null; + }).when(client).search(any(), any()); + when(detectionIndices.doesJobIndexExist()).thenReturn(true); + when(detectionIndices.doesStateIndexExist()).thenReturn(true); + + adDataMigrator.migrateData(); + verify(adDataMigrator, never()).backfillRealtimeTask(any(), anyBoolean()); + assertTrue(adDataMigrator.isMigrated()); + } +} diff --git a/src/test/java/org/opensearch/ad/cluster/ADVersionUtilTests.java-e 
b/src/test/java/org/opensearch/ad/cluster/ADVersionUtilTests.java-e new file mode 100644 index 000000000..aa5fcc55b --- /dev/null +++ b/src/test/java/org/opensearch/ad/cluster/ADVersionUtilTests.java-e @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.cluster; + +import org.opensearch.Version; +import org.opensearch.ad.ADUnitTestCase; + +public class ADVersionUtilTests extends ADUnitTestCase { + + public void testParseVersionFromString() { + Version version = ADVersionUtil.fromString("2.1.0.0"); + assertEquals(Version.V_2_1_0, version); + + version = ADVersionUtil.fromString("2.1.0"); + assertEquals(Version.V_2_1_0, version); + } + + public void testParseVersionFromStringWithNull() { + expectThrows(IllegalArgumentException.class, () -> ADVersionUtil.fromString(null)); + } + + public void testParseVersionFromStringWithWrongFormat() { + expectThrows(IllegalArgumentException.class, () -> ADVersionUtil.fromString("1.1")); + } +} diff --git a/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java-e b/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java-e new file mode 100644 index 000000000..637c5e10e --- /dev/null +++ b/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java-e @@ -0,0 +1,130 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.cluster; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.time.Clock; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Locale; + +import org.junit.Before; +import org.opensearch.ad.cluster.diskcleanup.ModelCheckpointIndexRetention; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.component.LifecycleListener; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.threadpool.Scheduler.Cancellable; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +public class ClusterManagerEventListenerTests extends AbstractTimeSeriesTest { + private ClusterService clusterService; + private ThreadPool threadPool; + private Client client; + private Clock clock; + private Cancellable hourlyCancellable; + private Cancellable checkpointIndexRetentionCancellable; + private ClusterManagerEventListener clusterManagerService; + private ClientUtil clientUtil; + private DiscoveryNodeFilterer nodeFilter; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + clusterService = mock(ClusterService.class); + ClusterSettings settings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.CHECKPOINT_TTL))) + ); + when(clusterService.getClusterSettings()).thenReturn(settings); + + threadPool = mock(ThreadPool.class); + hourlyCancellable = mock(Cancellable.class); + checkpointIndexRetentionCancellable = mock(Cancellable.class); + when(threadPool.scheduleWithFixedDelay(any(HourlyCron.class), any(TimeValue.class), any(String.class))) + .thenReturn(hourlyCancellable); + when(threadPool.scheduleWithFixedDelay(any(ModelCheckpointIndexRetention.class), any(TimeValue.class), any(String.class))) + .thenReturn(checkpointIndexRetentionCancellable); + client = mock(Client.class); + clock = mock(Clock.class); + clientUtil = mock(ClientUtil.class); + HashMap ignoredAttributes = new HashMap(); + ignoredAttributes.put(ADCommonName.BOX_TYPE_KEY, ADCommonName.WARM_BOX_TYPE); + nodeFilter = new DiscoveryNodeFilterer(clusterService); + + clusterManagerService = new ClusterManagerEventListener( + clusterService, + threadPool, + client, + clock, + clientUtil, + nodeFilter, + AnomalyDetectorSettings.CHECKPOINT_TTL, + Settings.EMPTY + ); + } + + public void testOnOffClusterManager() { + clusterManagerService.onClusterManager(); + assertThat(hourlyCancellable, is(notNullValue())); + assertThat(checkpointIndexRetentionCancellable, is(notNullValue())); + assertTrue(!clusterManagerService.getHourlyCron().isCancelled()); + assertTrue(!clusterManagerService.getCheckpointIndexRetentionCron().isCancelled()); + clusterManagerService.offClusterManager(); + 
assertThat(clusterManagerService.getCheckpointIndexRetentionCron(), is(nullValue())); + assertThat(clusterManagerService.getHourlyCron(), is(nullValue())); + } + + public void testBeforeStop() { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + args.length == 1 + ); + + LifecycleListener listener = null; + if (args[0] instanceof LifecycleListener) { + listener = (LifecycleListener) args[0]; + } + + assertTrue(listener != null); + listener.beforeStop(); + + return null; + }).when(clusterService).addLifecycleListener(any()); + + clusterManagerService.onClusterManager(); + assertThat(clusterManagerService.getCheckpointIndexRetentionCron(), is(nullValue())); + assertThat(clusterManagerService.getHourlyCron(), is(nullValue())); + clusterManagerService.offClusterManager(); + assertThat(clusterManagerService.getCheckpointIndexRetentionCron(), is(nullValue())); + assertThat(clusterManagerService.getHourlyCron(), is(nullValue())); + } +} diff --git a/src/test/java/org/opensearch/ad/cluster/DailyCronTests.java-e b/src/test/java/org/opensearch/ad/cluster/DailyCronTests.java-e new file mode 100644 index 000000000..63d48ef3c --- /dev/null +++ b/src/test/java/org/opensearch/ad/cluster/DailyCronTests.java-e @@ -0,0 +1,99 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.cluster; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.time.Clock; +import java.time.Duration; +import java.util.Arrays; +import java.util.Locale; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.timeseries.AbstractTimeSeriesTest; + +public class DailyCronTests extends AbstractTimeSeriesTest { + + enum DailyCronTestExecutionMode { + NORMAL, + INDEX_NOT_EXIST, + FAIL + } + + @Override + public void setUp() throws Exception { + super.setUp(); + super.setUpLog4jForJUnit(DailyCron.class); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + super.tearDownLog4jForJUnit(); + } + + @SuppressWarnings("unchecked") + private void templateDailyCron(DailyCronTestExecutionMode mode) { + Clock clock = mock(Clock.class); + ClientUtil clientUtil = mock(ClientUtil.class); + DailyCron cron = new DailyCron(clock, Duration.ofHours(24), clientUtil); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. 
Its content is %s", args.length, Arrays.toString(args)), + args.length == 3 + ); + assertTrue(args[2] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[2]; + + if (mode == DailyCronTestExecutionMode.INDEX_NOT_EXIST) { + listener.onFailure(new IndexNotFoundException("foo", "bar")); + } else if (mode == DailyCronTestExecutionMode.FAIL) { + listener.onFailure(new OpenSearchException("bar")); + } else { + BulkByScrollResponse deleteByQueryResponse = mock(BulkByScrollResponse.class); + when(deleteByQueryResponse.getDeleted()).thenReturn(10L); + listener.onResponse(deleteByQueryResponse); + } + + return null; + }).when(clientUtil).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); + + cron.run(); + } + + public void testNormal() { + templateDailyCron(DailyCronTestExecutionMode.NORMAL); + assertTrue(testAppender.containsMessage(DailyCron.CHECKPOINT_DELETED_MSG)); + } + + public void testCheckpointNotExist() { + templateDailyCron(DailyCronTestExecutionMode.INDEX_NOT_EXIST); + assertTrue(testAppender.containsMessage(DailyCron.CHECKPOINT_NOT_EXIST_MSG)); + } + + public void testFail() { + templateDailyCron(DailyCronTestExecutionMode.FAIL); + assertTrue(testAppender.containsMessage(DailyCron.CANNOT_DELETE_OLD_CHECKPOINT_MSG)); + } +} diff --git a/src/test/java/org/opensearch/ad/cluster/HashRingTests.java b/src/test/java/org/opensearch/ad/cluster/HashRingTests.java index b6474678c..e85051dd9 100644 --- a/src/test/java/org/opensearch/ad/cluster/HashRingTests.java +++ b/src/test/java/org/opensearch/ad/cluster/HashRingTests.java @@ -50,6 +50,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.plugins.PluginInfo; +import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import com.google.common.collect.ImmutableList; @@ -248,7 +249,7 @@ private NodeInfo createNodeInfo(DiscoveryNode node, String version) { plugins .add( new PluginInfo( - ADCommonName.AD_PLUGIN_NAME, + CommonName.TIME_SERIES_PLUGIN_NAME, randomAlphaOfLengthBetween(3, 10), version, Version.CURRENT, diff --git a/src/test/java/org/opensearch/ad/cluster/HashRingTests.java-e b/src/test/java/org/opensearch/ad/cluster/HashRingTests.java-e new file mode 100644 index 000000000..e85051dd9 --- /dev/null +++ b/src/test/java/org/opensearch/ad/cluster/HashRingTests.java-e @@ -0,0 +1,284 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.cluster; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyMap; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.COOLDOWN_MINUTES; + +import java.net.UnknownHostException; +import java.time.Clock; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; + +import org.junit.Before; +import org.opensearch.Build; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.cluster.node.info.NodeInfo; +import org.opensearch.action.admin.cluster.node.info.NodesInfoResponse; +import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; +import org.opensearch.ad.ADUnitTestCase; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.plugins.PluginInfo; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class HashRingTests extends ADUnitTestCase { + + private ClusterService clusterService; + private DiscoveryNodeFilterer nodeFilter; + private Settings settings; + private Clock clock; + private Client client; + private ClusterAdminClient clusterAdminClient; + private AdminClient adminClient; + private ADDataMigrator dataMigrator; + private HashRing hashRing; + private DiscoveryNodes.Delta delta; + private String localNodeId; + private String newNodeId; + private String warmNodeId; + private DiscoveryNode localNode; + private DiscoveryNode newNode; + private DiscoveryNode warmNode; + private ModelManager modelManager; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + localNodeId = "localNode"; + localNode = createNode(localNodeId, "127.0.0.1", 9200, emptyMap()); + newNodeId = "newNode"; + newNode = createNode(newNodeId, "127.0.0.2", 9201, emptyMap()); + warmNodeId = "warmNode"; + warmNode = createNode(warmNodeId, "127.0.0.3", 9202, ImmutableMap.of(ADCommonName.BOX_TYPE_KEY, ADCommonName.WARM_BOX_TYPE)); + + settings = Settings.builder().put(COOLDOWN_MINUTES.getKey(), TimeValue.timeValueSeconds(5)).build(); + ClusterSettings clusterSettings = clusterSetting(settings, COOLDOWN_MINUTES); + clusterService = spy(new ClusterService(settings, clusterSettings, null)); + + nodeFilter = spy(new DiscoveryNodeFilterer(clusterService)); + client = mock(Client.class); + dataMigrator = mock(ADDataMigrator.class); + + clock = mock(Clock.class); + when(clock.millis()).thenReturn(700000L); + + delta = mock(DiscoveryNodes.Delta.class); + + adminClient = mock(AdminClient.class); + 
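// wire the mocked admin and cluster-admin clients so the hash ring can issue node-info requests in the tests below +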
when(client.admin()).thenReturn(adminClient); + clusterAdminClient = mock(ClusterAdminClient.class); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + + String modelId = "123_model_threshold"; + modelManager = mock(ModelManager.class); + doAnswer(invocation -> { + Set res = new HashSet<>(); + res.add(modelId); + return res; + }).when(modelManager).getAllModelIds(); + + hashRing = spy(new HashRing(nodeFilter, clock, settings, client, clusterService, dataMigrator, modelManager)); + } + + public void testGetOwningNodeWithEmptyResult() throws UnknownHostException { + DiscoveryNode node1 = createNode(Integer.toString(1), "127.0.0.4", 9204, emptyMap()); + doReturn(node1).when(clusterService).localNode(); + + Optional node = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD("http-latency-rcf-1"); + assertFalse(node.isPresent()); + } + + public void testGetOwningNode() throws UnknownHostException { + List addedNodes = setupNodeDelta(); + + // Add first node, + hashRing.buildCircles(delta, ActionListener.wrap(r -> { + Optional node = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD("http-latency-rcf-1"); + assertTrue(node.isPresent()); + assertTrue(asList(newNodeId, localNodeId).contains(node.get().getId())); + DiscoveryNode[] nodesWithSameLocalAdVersion = hashRing.getNodesWithSameLocalAdVersion(); + Set nodesWithSameLocalAdVersionIds = new HashSet<>(); + for (DiscoveryNode n : nodesWithSameLocalAdVersion) { + nodesWithSameLocalAdVersionIds.add(n.getId()); + } + assertFalse("Should not build warm node into hash ring", nodesWithSameLocalAdVersionIds.contains(warmNodeId)); + assertEquals("Wrong hash ring size", 2, nodesWithSameLocalAdVersion.length); + assertEquals( + "Wrong hash ring size for historical analysis", + 2, + hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, false).size() + ); + // Circles for realtime AD will change as it's eligible to build for when its empty + assertEquals("Wrong hash ring size for realtime AD", 2, hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, true).size()); + }, e -> { + logger.error("building hash ring failed", e); + assertFalse("Build hash ring failed", true); + })); + + // Second new node joins cluster, test realtime circles will not update. + String newNodeId2 = "newNode2"; + DiscoveryNode newNode2 = createNode(newNodeId2, "127.0.0.4", 9200, emptyMap()); + addedNodes.add(newNode2); + when(delta.addedNodes()).thenReturn(addedNodes); + setupClusterAdminClient(localNode, newNode, newNode2); + hashRing.buildCircles(delta, ActionListener.wrap(r -> { + assertEquals( + "Wrong hash ring size for historical analysis", + 3, + hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, false).size() + ); + // Circles for realtime AD will not change as it's eligible to rebuild + assertEquals("Wrong hash ring size for realtime AD", 2, hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, true).size()); + }, e -> { + logger.error("building hash ring failed", e); + + assertFalse("Build hash ring failed", true); + })); + + // Mock it's eligible to rebuild circles for realtime AD, then add new node. Realtime circles should change. 
+ when(hashRing.eligibleToRebuildCirclesForRealtimeAD()).thenReturn(true); + String newNodeId3 = "newNode3"; + DiscoveryNode newNode3 = createNode(newNodeId3, "127.0.0.5", 9200, emptyMap()); + addedNodes.add(newNode3); + when(delta.addedNodes()).thenReturn(addedNodes); + setupClusterAdminClient(localNode, newNode, newNode2, newNode3); + hashRing.buildCircles(delta, ActionListener.wrap(r -> { + assertEquals( + "Wrong hash ring size for historical analysis", + 4, + hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, false).size() + ); + assertEquals("Wrong hash ring size for realtime AD", 4, hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, true).size()); + }, e -> { + logger.error("building hash ring failed", e); + assertFalse("Failed to build hash ring", true); + })); + } + + public void testGetAllEligibleDataNodesWithKnownAdVersionAndGetNodeByAddress() { + setupNodeDelta(); + hashRing.getAllEligibleDataNodesWithKnownAdVersion(nodes -> { + assertEquals("Wrong hash ring size for historical analysis", 2, nodes.length); + Optional node = hashRing.getNodeByAddress(newNode.getAddress()); + assertTrue(node.isPresent()); + assertEquals(newNodeId, node.get().getId()); + }, ActionListener.wrap(r -> {}, e -> { assertFalse("Failed to build hash ring", true); })); + } + + public void testBuildAndGetOwningNodeWithSameLocalAdVersion() { + setupNodeDelta(); + hashRing + .buildAndGetOwningNodeWithSameLocalAdVersion( + "testModelId", + node -> { assertTrue(node.isPresent()); }, + ActionListener.wrap(r -> {}, e -> { assertFalse("Failed to build hash ring", true); }) + ); + } + + private List setupNodeDelta() { + List addedNodes = new ArrayList<>(); + addedNodes.add(newNode); + + List removedNodes = asList(); + + when(delta.removed()).thenReturn(false); + when(delta.added()).thenReturn(true); + when(delta.removedNodes()).thenReturn(removedNodes); + when(delta.addedNodes()).thenReturn(addedNodes); + + doReturn(localNode).when(clusterService).localNode(); + setupClusterAdminClient(localNode, newNode, warmNode); + + doReturn(new DiscoveryNode[] { localNode, newNode }).when(nodeFilter).getEligibleDataNodes(); + doReturn(new DiscoveryNode[] { localNode, newNode, warmNode }).when(nodeFilter).getAllNodes(); + return addedNodes; + } + + private void setupClusterAdminClient(DiscoveryNode... 
nodes) { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + List nodeInfos = new ArrayList<>(); + for (DiscoveryNode node : nodes) { + nodeInfos.add(createNodeInfo(node, "2.1.0.0")); + } + NodesInfoResponse nodesInfoResponse = new NodesInfoResponse(ClusterName.DEFAULT, nodeInfos, ImmutableList.of()); + listener.onResponse(nodesInfoResponse); + return null; + }).when(clusterAdminClient).nodesInfo(any(), any()); + } + + private NodeInfo createNodeInfo(DiscoveryNode node, String version) { + List plugins = new ArrayList<>(); + plugins + .add( + new PluginInfo( + CommonName.TIME_SERIES_PLUGIN_NAME, + randomAlphaOfLengthBetween(3, 10), + version, + Version.CURRENT, + "1.8", + randomAlphaOfLengthBetween(3, 10), + randomAlphaOfLengthBetween(3, 10), + ImmutableList.of(), + randomBoolean() + ) + ); + List modules = new ArrayList<>(); + modules.addAll(plugins); + PluginsAndModules pluginsAndModules = new PluginsAndModules(plugins, modules); + return new NodeInfo( + Version.CURRENT, + Build.CURRENT, + node, + settings, + null, + null, + null, + null, + null, + null, + pluginsAndModules, + null, + null, + null, + null + ); + } +} diff --git a/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java b/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java index a3d151c0f..6461a7b3e 100644 --- a/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java +++ b/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java @@ -37,7 +37,7 @@ import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; diff --git a/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java-e b/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java-e new file mode 100644 index 000000000..6461a7b3e --- /dev/null +++ b/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java-e @@ -0,0 +1,139 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.cluster; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Locale; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchException; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.FailedNodeException; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.transport.CronAction; +import org.opensearch.ad.transport.CronNodeResponse; +import org.opensearch.ad.transport.CronResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +import test.org.opensearch.ad.util.ClusterCreation; + +public class HourlyCronTests extends AbstractTimeSeriesTest { + + enum HourlyCronTestExecutionMode { + NORMAL, + NODE_FAIL, + ALL_FAIL + } + + @SuppressWarnings("unchecked") + public void templateHourlyCron(HourlyCronTestExecutionMode mode) { + super.setUpLog4jForJUnit(HourlyCron.class); + + ClusterService clusterService = mock(ClusterService.class); + ClusterState state = ClusterCreation.state(1); + when(clusterService.state()).thenReturn(state); + HashMap ignoredAttributes = new HashMap(); + ignoredAttributes.put(ADCommonName.BOX_TYPE_KEY, ADCommonName.WARM_BOX_TYPE); + DiscoveryNodeFilterer nodeFilter = new DiscoveryNodeFilterer(clusterService); + + Client client = mock(Client.class); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. 
Its content is %s", args.length, Arrays.toString(args)), + args.length == 3 + ); + assertTrue(args[2] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[2]; + + if (mode == HourlyCronTestExecutionMode.NODE_FAIL) { + listener + .onResponse( + new CronResponse( + new ClusterName("test"), + Collections.singletonList(new CronNodeResponse(state.nodes().getLocalNode())), + Collections.singletonList(new FailedNodeException("foo0", "blah", new OpenSearchException("bar"))) + ) + ); + } else if (mode == HourlyCronTestExecutionMode.ALL_FAIL) { + listener.onFailure(new OpenSearchException("bar")); + } else { + CronNodeResponse nodeResponse = new CronNodeResponse(state.nodes().getLocalNode()); + BytesStreamOutput nodeResponseOut = new BytesStreamOutput(); + nodeResponseOut.setVersion(Version.CURRENT); + nodeResponse.writeTo(nodeResponseOut); + StreamInput siNode = nodeResponseOut.bytes().streamInput(); + + CronNodeResponse nodeResponseRead = new CronNodeResponse(siNode); + + CronResponse response = new CronResponse( + new ClusterName("test"), + Collections.singletonList(nodeResponseRead), + Collections.EMPTY_LIST + ); + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(Version.CURRENT); + response.writeTo(out); + StreamInput si = out.bytes().streamInput(); + CronResponse responseRead = new CronResponse(si); + listener.onResponse(responseRead); + } + + return null; + }).when(client).execute(eq(CronAction.INSTANCE), any(), any()); + + HourlyCron cron = new HourlyCron(client, nodeFilter); + cron.run(); + + Logger LOG = LogManager.getLogger(HourlyCron.class); + LOG.info(testAppender.messages); + if (mode == HourlyCronTestExecutionMode.NODE_FAIL) { + assertTrue(testAppender.containsMessage(HourlyCron.NODE_EXCEPTION_LOG_MSG)); + } else if (mode == HourlyCronTestExecutionMode.ALL_FAIL) { + assertTrue(testAppender.containsMessage(HourlyCron.EXCEPTION_LOG_MSG)); + } else { + assertTrue(testAppender.containsMessage(HourlyCron.SUCCEEDS_LOG_MSG)); + } + + super.tearDownLog4jForJUnit(); + } + + public void testNormal() { + templateHourlyCron(HourlyCronTestExecutionMode.NORMAL); + } + + public void testAllFail() { + templateHourlyCron(HourlyCronTestExecutionMode.ALL_FAIL); + } + + public void testNodeFail() throws Exception { + templateHourlyCron(HourlyCronTestExecutionMode.NODE_FAIL); + } +} diff --git a/src/test/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanupTests.java-e b/src/test/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanupTests.java-e new file mode 100644 index 000000000..1425a5ec3 --- /dev/null +++ b/src/test/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanupTests.java-e @@ -0,0 +1,126 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.cluster.diskcleanup; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import org.apache.logging.log4j.Level; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.core.Logger; +import org.junit.Assert; +import org.mockito.Answers; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.indices.stats.CommonStats; +import org.opensearch.action.admin.indices.stats.IndicesStatsResponse; +import org.opensearch.action.admin.indices.stats.ShardStats; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.store.StoreStats; +import org.opensearch.timeseries.AbstractTimeSeriesTest; + +public class IndexCleanupTests extends AbstractTimeSeriesTest { + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + Client client; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + ClusterService clusterService; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + ClientUtil clientUtil; + + IndexCleanup indexCleanup; + + @Mock + IndicesStatsResponse indicesStatsResponse; + + @Mock + ShardStats shardStats; + + @Mock + CommonStats commonStats; + + @Mock + StoreStats storeStats; + + @Mock + IndicesAdminClient indicesAdminClient; + + @SuppressWarnings("unchecked") + @Override + public void setUp() throws Exception { + super.setUp(); + + MockitoAnnotations.initMocks(this); + when(clusterService.state().getRoutingTable().hasIndex(anyString())).thenReturn(true); + super.setUpLog4jForJUnit(IndexCleanup.class); + indexCleanup = new IndexCleanup(client, clientUtil, clusterService); + when(indicesStatsResponse.getShards()).thenReturn(new ShardStats[] { shardStats }); + when(shardStats.getStats()).thenReturn(commonStats); + when(commonStats.getStore()).thenReturn(storeStats); + when(client.admin().indices()).thenReturn(indicesAdminClient); + when(client.threadPool().getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + listener.onResponse(indicesStatsResponse); + return null; + }).when(indicesAdminClient).stats(any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + super.tearDownLog4jForJUnit(); + } + + public void testDeleteDocsBasedOnShardSizeWithCleanupNeededAsTrue() throws Exception { + long maxShardSize = 1000; + when(storeStats.getSizeInBytes()).thenReturn(maxShardSize + 1); + indexCleanup.deleteDocsBasedOnShardSize("indexname", maxShardSize, null, ActionListener.wrap(result -> { + assertTrue(result); + verify(clientUtil).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); + }, exception -> { throw new RuntimeException(exception); })); + } + + public void testDeleteDocsBasedOnShardSizeWithCleanupNeededAsFalse() throws Exception { + long maxShardSize = 1000; + when(storeStats.getSizeInBytes()).thenReturn(maxShardSize - 1); + indexCleanup + 
.deleteDocsBasedOnShardSize( + "indexname", + maxShardSize, + null, + ActionListener.wrap(Assert::assertFalse, exception -> { throw new RuntimeException(exception); }) + ); + } + + public void testDeleteDocsBasedOnShardSizeIndexNotExisted() throws Exception { + when(clusterService.state().getRoutingTable().hasIndex(anyString())).thenReturn(false); + Logger logger = (Logger) LogManager.getLogger(IndexCleanup.class); + logger.setLevel(Level.DEBUG); + indexCleanup.deleteDocsBasedOnShardSize("indexname", 1000, null, null); + assertTrue(testAppender.containsMessage("skip as the index:indexname doesn't exist")); + } +} diff --git a/src/test/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetentionTests.java-e b/src/test/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetentionTests.java-e new file mode 100644 index 000000000..0222a4d47 --- /dev/null +++ b/src/test/java/org/opensearch/ad/cluster/diskcleanup/ModelCheckpointIndexRetentionTests.java-e @@ -0,0 +1,99 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.cluster.diskcleanup; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.time.Clock; +import java.time.Duration; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.timeseries.AbstractTimeSeriesTest; + +public class ModelCheckpointIndexRetentionTests extends AbstractTimeSeriesTest { + + Duration defaultCheckpointTtl = Duration.ofDays(3); + + Clock clock = Clock.systemUTC(); + + @Mock + IndexCleanup indexCleanup; + + ModelCheckpointIndexRetention modelCheckpointIndexRetention; + + @SuppressWarnings("unchecked") + @Before + public void setUp() throws Exception { + super.setUp(); + super.setUpLog4jForJUnit(IndexCleanup.class); + MockitoAnnotations.initMocks(this); + modelCheckpointIndexRetention = new ModelCheckpointIndexRetention(defaultCheckpointTtl, clock, indexCleanup); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[2]; + listener.onResponse(1L); + return null; + }).when(indexCleanup).deleteDocsByQuery(anyString(), any(), any()); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + super.tearDownLog4jForJUnit(); + } + + @SuppressWarnings("unchecked") + @Test + public void testRunWithCleanupAsNeeded() throws Exception { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[3]; + listener.onResponse(true); + return null; + }) + .when(indexCleanup) + .deleteDocsBasedOnShardSize(eq(ADCommonName.CHECKPOINT_INDEX_NAME), eq(50 * 1024 * 1024 * 1024L), any(), any()); + + modelCheckpointIndexRetention.run(); + verify(indexCleanup, times(2)) + .deleteDocsBasedOnShardSize(eq(ADCommonName.CHECKPOINT_INDEX_NAME), eq(50 * 1024 * 1024 * 1024L), any(), any()); + 
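// the stubbed 'true' response should trigger a second size-based cleanup pass, while the TTL delete-by-query is still expected exactly once +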
verify(indexCleanup).deleteDocsByQuery(eq(ADCommonName.CHECKPOINT_INDEX_NAME), any(), any()); + } + + @SuppressWarnings("unchecked") + @Test + public void testRunWithCleanupAsFalse() throws Exception { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[3]; + listener.onResponse(false); + return null; + }) + .when(indexCleanup) + .deleteDocsBasedOnShardSize(eq(ADCommonName.CHECKPOINT_INDEX_NAME), eq(50 * 1024 * 1024 * 1024L), any(), any()); + + modelCheckpointIndexRetention.run(); + verify(indexCleanup).deleteDocsBasedOnShardSize(eq(ADCommonName.CHECKPOINT_INDEX_NAME), eq(50 * 1024 * 1024 * 1024L), any(), any()); + verify(indexCleanup).deleteDocsByQuery(eq(ADCommonName.CHECKPOINT_INDEX_NAME), any(), any()); + } +} diff --git a/src/test/java/org/opensearch/ad/common/exception/ADTaskCancelledExceptionTests.java-e b/src/test/java/org/opensearch/ad/common/exception/ADTaskCancelledExceptionTests.java-e new file mode 100644 index 000000000..d66573379 --- /dev/null +++ b/src/test/java/org/opensearch/ad/common/exception/ADTaskCancelledExceptionTests.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.common.exception; + +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.common.exception.TaskCancelledException; + +public class ADTaskCancelledExceptionTests extends OpenSearchTestCase { + + public void testConstructor() { + String message = randomAlphaOfLength(5); + String user = randomAlphaOfLength(5); + TaskCancelledException exception = new TaskCancelledException(message, user); + assertEquals(message, exception.getMessage()); + assertEquals(user, exception.getCancelledBy()); + } +} diff --git a/src/test/java/org/opensearch/ad/common/exception/JsonPathNotFoundException.java-e b/src/test/java/org/opensearch/ad/common/exception/JsonPathNotFoundException.java-e new file mode 100644 index 000000000..bd8daf936 --- /dev/null +++ b/src/test/java/org/opensearch/ad/common/exception/JsonPathNotFoundException.java-e @@ -0,0 +1,20 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.common.exception; + +public class JsonPathNotFoundException extends Exception { + + public JsonPathNotFoundException() { + super("Invalid Json path"); + } + +} diff --git a/src/test/java/org/opensearch/ad/common/exception/LimitExceededExceptionTests.java-e b/src/test/java/org/opensearch/ad/common/exception/LimitExceededExceptionTests.java-e new file mode 100644 index 000000000..37b3770ff --- /dev/null +++ b/src/test/java/org/opensearch/ad/common/exception/LimitExceededExceptionTests.java-e @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.common.exception; + +import static org.junit.Assert.assertEquals; + +import org.junit.Test; +import org.opensearch.timeseries.common.exception.LimitExceededException; + +public class LimitExceededExceptionTests { + + @Test + public void testConstructorWithIdAndExplanation() { + String id = "test id"; + String message = "test message"; + LimitExceededException limitExceeded = new LimitExceededException(id, message); + assertEquals(id, limitExceeded.getConfigId()); + assertEquals(message, limitExceeded.getMessage()); + } +} diff --git a/src/test/java/org/opensearch/ad/common/exception/NotSerializedADExceptionNameTests.java b/src/test/java/org/opensearch/ad/common/exception/NotSerializedADExceptionNameTests.java index b76e94c64..0e544b409 100644 --- a/src/test/java/org/opensearch/ad/common/exception/NotSerializedADExceptionNameTests.java +++ b/src/test/java/org/opensearch/ad/common/exception/NotSerializedADExceptionNameTests.java @@ -13,7 +13,7 @@ import java.util.Optional; -import org.opensearch.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.timeseries.common.exception.ClientException; import org.opensearch.timeseries.common.exception.DuplicateTaskException; diff --git a/src/test/java/org/opensearch/ad/common/exception/NotSerializedADExceptionNameTests.java-e b/src/test/java/org/opensearch/ad/common/exception/NotSerializedADExceptionNameTests.java-e new file mode 100644 index 000000000..0e544b409 --- /dev/null +++ b/src/test/java/org/opensearch/ad/common/exception/NotSerializedADExceptionNameTests.java-e @@ -0,0 +1,74 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.common.exception; + +import java.util.Optional; + +import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.common.exception.ClientException; +import org.opensearch.timeseries.common.exception.DuplicateTaskException; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.common.exception.NotSerializedExceptionName; +import org.opensearch.timeseries.common.exception.TaskCancelledException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.common.exception.ValidationException; + +public class NotSerializedADExceptionNameTests extends OpenSearchTestCase { + public void testConvertAnomalyDetectionException() { + Optional converted = NotSerializedExceptionName + .convertWrappedTimeSeriesException(new NotSerializableExceptionWrapper(new TimeSeriesException("", "")), ""); + assertTrue(converted.isPresent()); + assertTrue(converted.get() instanceof TimeSeriesException); + } + + public void testConvertInternalFailure() { + Optional converted = NotSerializedExceptionName + .convertWrappedTimeSeriesException(new NotSerializableExceptionWrapper(new InternalFailure("", "")), ""); + assertTrue(converted.isPresent()); + assertTrue(converted.get() instanceof InternalFailure); + } + + public void testConvertClientException() { + Optional converted = NotSerializedExceptionName + .convertWrappedTimeSeriesException(new NotSerializableExceptionWrapper(new ClientException("", "")), ""); + assertTrue(converted.isPresent()); + assertTrue(converted.get() instanceof ClientException); + } + + public void testConvertADTaskCancelledException() { + Optional converted = NotSerializedExceptionName + .convertWrappedTimeSeriesException(new NotSerializableExceptionWrapper(new TaskCancelledException("", "")), ""); + assertTrue(converted.isPresent()); + assertTrue(converted.get() instanceof TaskCancelledException); + } + + public void testConvertDuplicateTaskException() { + Optional converted = NotSerializedExceptionName + .convertWrappedTimeSeriesException(new NotSerializableExceptionWrapper(new DuplicateTaskException("")), ""); + assertTrue(converted.isPresent()); + assertTrue(converted.get() instanceof DuplicateTaskException); + } + + public void testConvertADValidationException() { + Optional converted = NotSerializedExceptionName + .convertWrappedTimeSeriesException(new NotSerializableExceptionWrapper(new ValidationException("", null, null)), ""); + assertTrue(converted.isPresent()); + assertTrue(converted.get() instanceof ValidationException); + } + + public void testUnknownException() { + Optional converted = NotSerializedExceptionName + .convertWrappedTimeSeriesException(new NotSerializableExceptionWrapper(new RuntimeException("")), ""); + assertTrue(!converted.isPresent()); + } +} diff --git a/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java-e b/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java-e new file mode 100644 index 000000000..919b3e068 --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java-e @@ -0,0 +1,243 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.e2e; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; +import static org.opensearch.timeseries.TestHelpers.toHttpEntity; + +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.charset.Charset; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.opensearch.ad.ODFERestTestCase; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.client.WarningsHandler; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.TestHelpers; + +import com.google.common.collect.ImmutableList; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; + +public class AbstractSyntheticDataTest extends ODFERestTestCase { + /** + * In real time AD, we mute a node for a detector if that node keeps returning + * ResourceNotFoundException (5 times in a row). This is a problem for batch mode + * testing as we issue a large amount of requests quickly. Due to the speed, we + * won't be able to finish cold start before the ResourceNotFoundException mutes + * a node. Since our test case has only one node, there is no other nodes to fall + * back on. Here we disable such fault tolerance by setting max retries before + * muting to a large number and the actual wait time during muting to 0. 
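+ * Concretely, the request below issues a PUT to /_cluster/settings with a persistent body that sets the max-retry setting to 100000 and the backoff-minutes setting to 0.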
+ * + * @throws IOException when failing to create http request body + */ + protected void disableResourceNotFoundFaultTolerence() throws IOException { + XContentBuilder settingCommand = JsonXContent.contentBuilder(); + + settingCommand.startObject(); + settingCommand.startObject("persistent"); + settingCommand.field(MAX_RETRY_FOR_UNRESPONSIVE_NODE.getKey(), 100_000); + settingCommand.field(BACKOFF_MINUTES.getKey(), 0); + settingCommand.endObject(); + settingCommand.endObject(); + Request request = new Request("PUT", "/_cluster/settings"); + request.setJsonEntity(org.opensearch.common.Strings.toString(settingCommand)); + + adminClient().performRequest(request); + } + + protected List getData(String datasetFileName) throws Exception { + JsonArray jsonArray = JsonParser + .parseReader(new FileReader(new File(getClass().getResource(datasetFileName).toURI()), Charset.defaultCharset())) + .getAsJsonArray(); + List list = new ArrayList<>(jsonArray.size()); + jsonArray.iterator().forEachRemaining(i -> list.add(i.getAsJsonObject())); + return list; + } + + protected Map getDetectionResult(String detectorId, Instant begin, Instant end, RestClient client) { + try { + Request request = new Request( + "POST", + String.format(Locale.ROOT, "/_opendistro/_anomaly_detection/detectors/%s/_run", detectorId) + ); + request + .setJsonEntity( + String.format(Locale.ROOT, "{ \"period_start\": %d, \"period_end\": %d }", begin.toEpochMilli(), end.toEpochMilli()) + ); + return entityAsMap(client.performRequest(request)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + protected void bulkIndexTrainData( + String datasetName, + List data, + int trainTestSplit, + RestClient client, + String categoryField + ) throws Exception { + Request request = new Request("PUT", datasetName); + String requestBody = null; + if (Strings.isEmpty(categoryField)) { + requestBody = "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," + + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" } } } }"; + } else { + requestBody = String + .format( + Locale.ROOT, + "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," + + " \"Feature1\": { \"type\": \"double\" }, \"Feature2\": { \"type\": \"double\" }," + + "\"%s\": { \"type\": \"keyword\"} } } }", + categoryField + ); + } + + request.setJsonEntity(requestBody); + setWarningHandler(request, false); + client.performRequest(request); + Thread.sleep(1_000); + + StringBuilder bulkRequestBuilder = new StringBuilder(); + for (int i = 0; i < trainTestSplit; i++) { + bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + datasetName + "\", \"_id\" : \"" + i + "\" } }\n"); + bulkRequestBuilder.append(data.get(i).toString()).append("\n"); + } + TestHelpers + .makeRequest( + client, + "POST", + "_bulk?refresh=true", + null, + toHttpEntity(bulkRequestBuilder.toString()), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + Thread.sleep(1_000); + waitAllSyncheticDataIngested(trainTestSplit, datasetName, client); + } + + protected String createDetector( + String datasetName, + int intervalMinutes, + RestClient client, + String categoryField, + long windowDelayInMins + ) throws Exception { + Request request = new Request("POST", "/_plugins/_anomaly_detection/detectors/"); + String requestBody = null; + if (Strings.isEmpty(categoryField)) { + requestBody = String + .format( + Locale.ROOT, + "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" + + ", 
\"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " + + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" + + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " + + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " + + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," + + "\"schema_version\": 0 }", + datasetName, + intervalMinutes, + windowDelayInMins + ); + } else { + requestBody = String + .format( + Locale.ROOT, + "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" + + ", \"indices\": [\"%s\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " + + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" + + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " + + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }, " + + "\"category_field\": [\"%s\"], " + + "\"window_delay\": { \"period\": {\"interval\": %d, \"unit\": \"MINUTES\"}}," + + "\"schema_version\": 0 }", + datasetName, + intervalMinutes, + categoryField, + windowDelayInMins + ); + } + + request.setJsonEntity(requestBody); + Map response = entityAsMap(client.performRequest(request)); + String detectorId = (String) response.get("_id"); + Thread.sleep(1_000); + return detectorId; + } + + @Override + protected void waitAllSyncheticDataIngested(int expectedSize, String datasetName, RestClient client) throws Exception { + int maxWaitCycles = 3; + do { + Request request = new Request("POST", String.format(Locale.ROOT, "/%s/_search", datasetName)); + request + .setJsonEntity( + String + .format( + Locale.ROOT, + "{\"query\": {" + + " \"match_all\": {}" + + " }," + + " \"size\": 1," + + " \"sort\": [" + + " {" + + " \"timestamp\": {" + + " \"order\": \"desc\"" + + " }" + + " }" + + " ]}" + ) + ); + // Make sure all of the test data has been ingested + // Expected response: + // "_index":"synthetic","_type":"_doc","_id":"10080","_score":null,"_source":{"timestamp":"2019-11-08T00:00:00Z","Feature1":156.30028000000001,"Feature2":100.211205,"host":"host1"},"sort":[1573171200000]} + Response response = client.performRequest(request); + JsonObject json = JsonParser + .parseReader(new InputStreamReader(response.getEntity().getContent(), Charset.defaultCharset())) + .getAsJsonObject(); + JsonArray hits = json.getAsJsonObject("hits").getAsJsonArray("hits"); + if (hits != null + && hits.size() == 1 + && expectedSize - 1 == hits.get(0).getAsJsonObject().getAsJsonPrimitive("_id").getAsLong()) { + break; + } else { + request = new Request("POST", String.format(Locale.ROOT, "/%s/_refresh", datasetName)); + client.performRequest(request); + } + Thread.sleep(1_000); + } while (maxWaitCycles-- >= 0); + } + + protected void setWarningHandler(Request request, boolean strictDeprecationMode) { + RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder(); + options.setWarningsHandler(strictDeprecationMode ? 
WarningsHandler.STRICT : WarningsHandler.PERMISSIVE); + request.setOptions(options.build()); + } +} diff --git a/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java-e b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java-e new file mode 100644 index 000000000..8edab0d15 --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/DetectionResultEvalutationIT.java-e @@ -0,0 +1,317 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.e2e; + +import static org.opensearch.timeseries.TestHelpers.toHttpEntity; + +import java.text.SimpleDateFormat; +import java.time.Clock; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Calendar; +import java.util.Date; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.core.Logger; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.common.xcontent.support.XContentMapValues; +import org.opensearch.timeseries.TestHelpers; + +import com.google.common.collect.ImmutableMap; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonPrimitive; + +public class DetectionResultEvalutationIT extends AbstractSyntheticDataTest { + protected static final Logger LOG = (Logger) LogManager.getLogger(DetectionResultEvalutationIT.class); + + /** + * Wait for HCAD cold start to finish. + * @param detectorId Detector Id + * @param data Data in Json format + * @param trainTestSplit Training data size + * @param shingleSize Shingle size + * @param intervalMinutes Detector Interval + * @param client OpenSearch Client + * @throws Exception when failing to query/indexing from/to OpenSearch + */ + private void waitForHCADStartDetector( + String detectorId, + List data, + int trainTestSplit, + int shingleSize, + int intervalMinutes, + RestClient client + ) throws Exception { + + long startTime = System.currentTimeMillis(); + long duration = 0; + do { + /* + * single stream detectors will throw exception if not finding models in the + * callback, while HCAD detectors will return early, record the exception in + * node state, and throw exception in the next run. HCAD did it this way since + * it does not know when current run is gonna finish (e.g, we may have millions + * of entities to process in one run). So for single-stream detector test case, + * we can check the exception to see if models are initialized or not. So HCAD, + * we have to either wait for next runs or use profile API. Here I chose profile + * API since it is faster. Will add these explanation in the comments. 
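+ * In short: the loop below polls the profile API about every five seconds until init_progress reaches 100% or roughly one minute has elapsed.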
+ */ + Thread.sleep(5_000); + String initProgress = profileDetectorInitProgress(detectorId, client); + if (initProgress.equals("100%")) { + break; + } + try { + profileDetectorInitProgress(detectorId, client); + } catch (Exception e) {} + duration = System.currentTimeMillis() - startTime; + } while (duration <= 60_000); + } + + public void testValidationIntervalRecommendation() throws Exception { + RestClient client = client(); + long recDetectorIntervalMillis = 180000; + long recDetectorIntervalMinutes = recDetectorIntervalMillis / 60000; + List data = createData(2000, recDetectorIntervalMillis); + indexTrainData("validation", data, 2000, client); + long detectorInterval = 1; + String requestBody = String + .format( + Locale.ROOT, + "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" + + ", \"indices\": [\"validation\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " + + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" + + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " + + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }" + + ",\"window_delay\":{\"period\":{\"interval\":10,\"unit\":\"Minutes\"}}}", + detectorInterval + ); + Response resp = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/_validate/model", + ImmutableMap.of(), + toHttpEntity(requestBody), + null + ); + Map responseMap = entityAsMap(resp); + @SuppressWarnings("unchecked") + Map> messageMap = (Map>) XContentMapValues + .extractValue("model", responseMap); + assertEquals( + ADCommonMessages.DETECTOR_INTERVAL_REC + recDetectorIntervalMinutes, + messageMap.get("detection_interval").get("message") + ); + } + + public void testValidationWindowDelayRecommendation() throws Exception { + RestClient client = client(); + long recDetectorIntervalMillisForDataSet = 180000; + // this would be equivalent to the window delay in this data test + List data = createData(2000, recDetectorIntervalMillisForDataSet); + indexTrainData("validation", data, 2000, client); + long detectorInterval = 4; + long expectedWindowDelayMillis = Instant.now().toEpochMilli() - data.get(0).get("timestamp").getAsLong(); + // we always round up for window delay recommendation to reduce chance of missed data. 
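+ // for example, a measured delay of 185000 ms becomes ceil(185000 / 60000.0) = 4 minutes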
+ long expectedWindowDelayMinutes = (long) Math.ceil(expectedWindowDelayMillis / 60000.0); + String requestBody = String + .format( + Locale.ROOT, + "{ \"name\": \"test\", \"description\": \"test\", \"time_field\": \"timestamp\"" + + ", \"indices\": [\"validation\"], \"feature_attributes\": [{ \"feature_name\": \"feature 1\", \"feature_enabled\": " + + "\"true\", \"aggregation_query\": { \"Feature1\": { \"sum\": { \"field\": \"Feature1\" } } } }, { \"feature_name\"" + + ": \"feature 2\", \"feature_enabled\": \"true\", \"aggregation_query\": { \"Feature2\": { \"sum\": { \"field\": " + + "\"Feature2\" } } } }], \"detection_interval\": { \"period\": { \"interval\": %d, \"unit\": \"Minutes\" } }" + + ",\"window_delay\":{\"period\":{\"interval\":1,\"unit\":\"Minutes\"}}}", + detectorInterval + ); + Response resp = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/_validate/model", + ImmutableMap.of(), + toHttpEntity(requestBody), + null + ); + Map responseMap = entityAsMap(resp); + @SuppressWarnings("unchecked") + Map> messageMap = (Map>) XContentMapValues + .extractValue("model", responseMap); + assertEquals( + String.format(Locale.ROOT, ADCommonMessages.WINDOW_DELAY_REC, expectedWindowDelayMinutes, expectedWindowDelayMinutes), + messageMap.get("window_delay").get("message") + ); + } + + private List createData(int numOfDataPoints, long detectorIntervalMS) { + List list = new ArrayList<>(); + for (int i = 1; i < numOfDataPoints; i++) { + long valueFeature1 = randomLongBetween(1, 10000000); + long valueFeature2 = randomLongBetween(1, 10000000); + JsonObject obj = new JsonObject(); + JsonElement element = new JsonPrimitive(Instant.now().toEpochMilli() - (detectorIntervalMS * i)); + obj.add("timestamp", element); + obj.add("Feature1", new JsonPrimitive(valueFeature1)); + obj.add("Feature2", new JsonPrimitive(valueFeature2)); + list.add(obj); + } + return list; + } + + private void indexTrainData(String datasetName, List data, int trainTestSplit, RestClient client) throws Exception { + Request request = new Request("PUT", datasetName); + String requestBody = "{ \"mappings\": { \"properties\": { \"timestamp\": { \"type\": \"date\"}," + + " \"Feature1\": { \"type\": \"long\" }, \"Feature2\": { \"type\": \"long\" } } } }"; + request.setJsonEntity(requestBody); + // a WarningFailureException on access system indices .opendistro_security will fail the test if this is not false. 
+ setWarningHandler(request, false); + client.performRequest(request); + Thread.sleep(1_000); + data.stream().limit(trainTestSplit).forEach(r -> { + try { + Request req = new Request("POST", String.format(Locale.ROOT, "/%s/_doc/", datasetName)); + req.setJsonEntity(r.toString()); + client.performRequest(req); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + Thread.sleep(3_000); + } + + public void testRestartHCADDetector() throws Exception { + // TODO: this test case will run for a much longer time and timeout with security enabled + if (!isHttps()) { + try { + disableResourceNotFoundFaultTolerence(); + verifyRestart("synthetic", 1, 8); + } catch (Throwable throwable) { + LOG.info("Retry restart test case", throwable); + cleanUpCluster(); + wipeAllODFEIndices(); + fail(); + } + } + } + + private void verifyRestart(String datasetName, int intervalMinutes, int shingleSize) throws Exception { + RestClient client = client(); + + String dataFileName = String.format(Locale.ROOT, "data/%s.data", datasetName); + + List data = getData(dataFileName); + + String categoricalField = "host"; + String tsField = "timestamp"; + + Clock clock = Clock.systemUTC(); + long currentMilli = clock.millis(); + int trainTestSplit = 1500; + + // e.g., 2019-11-01T00:03:00Z + String pattern = "yyyy-MM-dd'T'HH:mm:ss'Z'"; + SimpleDateFormat simpleDateFormat = new SimpleDateFormat(pattern, Locale.ROOT); + simpleDateFormat.setTimeZone(TimeZone.getTimeZone("UTC")); + // calculate the gap between current time and the beginning of last shingle + // the gap is used to adjust input training data's time so that the last + // few items of training data maps to current time. We need this adjustment + // because CompositeRetriever will compare expiry time with current time in hasNext + // method. The expiry time is calculated using request (one parameter of the run API) + // end time plus some fraction of interval. If the expiry time is less than + // current time, CompositeRetriever thinks this request expires and refuses to start + // querying. So this adjustment is to make the following simulateHCADStartDetector work. + String lastTrainShingleStartTime = data.get(trainTestSplit - shingleSize).getAsJsonPrimitive(tsField).getAsString(); + Date date = simpleDateFormat.parse(lastTrainShingleStartTime); + long diff = currentMilli - date.getTime(); + TimeUnit time = TimeUnit.MINUTES; + // by the time we trigger the run API, a few seconds have passed. +5 to make the adjusted time more than current time. 
+ long gap = time.convert(diff, TimeUnit.MILLISECONDS) + 5; + + Calendar c = Calendar.getInstance(TimeZone.getTimeZone("UTC"), Locale.ROOT); + + // only change training data as we only need to make sure detector is fully initialized + for (int i = 0; i < trainTestSplit; i++) { + JsonObject row = data.get(i); + // add categorical field since the original data is for single-stream detectors + row.addProperty(categoricalField, "host1"); + + String dateString = row.getAsJsonPrimitive(tsField).getAsString(); + date = simpleDateFormat.parse(dateString); + c.setTime(date); + c.add(Calendar.MINUTE, (int) gap); + String adjustedDate = simpleDateFormat.format(c.getTime()); + row.addProperty(tsField, adjustedDate); + } + + bulkIndexTrainData(datasetName, data, trainTestSplit, client, categoricalField); + + String detectorId = createDetector(datasetName, intervalMinutes, client, categoricalField, 0); + // cannot stop without actually starting detector because ad complains no ad job index + startDetector(detectorId, client); + profileDetectorInitProgress(detectorId, client); + // it would be long if we wait for the job actually run the work periodically; speed it up by using simulateHCADStartDetector + waitForHCADStartDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client); + String initProgress = profileDetectorInitProgress(detectorId, client); + assertEquals("init progress is " + initProgress, "100%", initProgress); + stopDetector(detectorId, client); + // restart detector + startDetector(detectorId, client); + waitForHCADStartDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client); + initProgress = profileDetectorInitProgress(detectorId, client); + assertEquals("init progress is " + initProgress, "100%", initProgress); + } + + private void stopDetector(String detectorId, RestClient client) throws Exception { + Request request = new Request("POST", String.format(Locale.ROOT, "/_plugins/_anomaly_detection/detectors/%s/_stop", detectorId)); + + Map response = entityAsMap(client.performRequest(request)); + String responseDetectorId = (String) response.get("_id"); + assertEquals(detectorId, responseDetectorId); + } + + private void startDetector(String detectorId, RestClient client) throws Exception { + Request request = new Request("POST", String.format(Locale.ROOT, "/_plugins/_anomaly_detection/detectors/%s/_start", detectorId)); + + Map response = entityAsMap(client.performRequest(request)); + String responseDetectorId = (String) response.get("_id"); + assertEquals(detectorId, responseDetectorId); + } + + private String profileDetectorInitProgress(String detectorId, RestClient client) throws Exception { + Request request = new Request( + "GET", + String.format(Locale.ROOT, "/_plugins/_anomaly_detection/detectors/%s/_profile/init_progress", detectorId) + ); + + Map response = entityAsMap(client.performRequest(request)); + /* + * Example response: + * { + * "init_progress": { + * "percentage": "100%" + * } + * } + */ + return (String) ((Map) response.get("init_progress")).get("percentage"); + } +} diff --git a/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java-e b/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java-e new file mode 100644 index 000000000..04f959442 --- /dev/null +++ b/src/test/java/org/opensearch/ad/e2e/SingleStreamModelPerfIT.java-e @@ -0,0 +1,230 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 
license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.e2e; + +import static org.opensearch.timeseries.TestHelpers.toHttpEntity; + +import java.io.File; +import java.io.FileReader; +import java.nio.charset.Charset; +import java.time.Instant; +import java.time.format.DateTimeFormatter; +import java.time.temporal.ChronoUnit; +import java.util.AbstractMap.SimpleEntry; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.core.Logger; +import org.opensearch.client.RestClient; +import org.opensearch.timeseries.TestHelpers; + +import com.google.common.collect.ImmutableList; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; + +public class SingleStreamModelPerfIT extends AbstractSyntheticDataTest { + protected static final Logger LOG = (Logger) LogManager.getLogger(SingleStreamModelPerfIT.class); + + public void testDataset() throws Exception { + // TODO: this test case will run for a much longer time and timeout with security enabled + if (!isHttps()) { + disableResourceNotFoundFaultTolerence(); + verifyAnomaly("synthetic", 1, 1500, 8, .4, .9, 10); + } + } + + private void verifyAnomaly( + String datasetName, + int intervalMinutes, + int trainTestSplit, + int shingleSize, + double minPrecision, + double minRecall, + double maxError + ) throws Exception { + RestClient client = client(); + + String dataFileName = String.format(Locale.ROOT, "data/%s.data", datasetName); + String labelFileName = String.format(Locale.ROOT, "data/%s.label", datasetName); + + List data = getData(dataFileName); + List> anomalies = getAnomalyWindows(labelFileName); + + bulkIndexTrainData(datasetName, data, trainTestSplit, client, null); + // single-stream detector can use window delay 0 here because we give the run api the actual data time + String detectorId = createDetector(datasetName, intervalMinutes, client, null, 0); + simulateSingleStreamStartDetector(detectorId, data, trainTestSplit, shingleSize, intervalMinutes, client); + bulkIndexTestData(data, datasetName, trainTestSplit, client); + double[] testResults = getTestResults(detectorId, data, trainTestSplit, intervalMinutes, anomalies, client); + verifyTestResults(testResults, anomalies, minPrecision, minRecall, maxError); + } + + private void verifyTestResults( + double[] testResults, + List> anomalies, + double minPrecision, + double minRecall, + double maxError + ) { + + double positives = testResults[0]; + double truePositives = testResults[1]; + double positiveAnomalies = testResults[2]; + double errors = testResults[3]; + + // precision = predicted anomaly points that are true / predicted anomaly points + double precision = positives > 0 ? truePositives / positives : 1; + assertTrue(precision >= minPrecision); + + // recall = windows containing predicted anomaly points / total anomaly windows + double recall = anomalies.size() > 0 ? 
positiveAnomalies / anomalies.size() : 1; + assertTrue(recall >= minRecall); + + assertTrue(errors <= maxError); + LOG.info("Precision: {}, Window recall: {}", precision, recall); + } + + private double[] getTestResults( + String detectorId, + List data, + int trainTestSplit, + int intervalMinutes, + List> anomalies, + RestClient client + ) throws Exception { + + double positives = 0; + double truePositives = 0; + Set positiveAnomalies = new HashSet<>(); + double errors = 0; + for (int i = trainTestSplit; i < data.size(); i++) { + Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(i).get("timestamp").getAsString())); + Instant end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); + try { + Map response = getDetectionResult(detectorId, begin, end, client); + double anomalyGrade = (double) response.get("anomalyGrade"); + if (anomalyGrade > 0) { + positives++; + int result = isAnomaly(begin, anomalies); + if (result != -1) { + truePositives++; + positiveAnomalies.add(result); + } + } + } catch (Exception e) { + errors++; + logger.error("failed to get detection results", e); + } + } + return new double[] { positives, truePositives, positiveAnomalies.size(), errors }; + } + + private List> getAnomalyWindows(String labalFileName) throws Exception { + JsonArray windows = JsonParser + .parseReader(new FileReader(new File(getClass().getResource(labalFileName).toURI()), Charset.defaultCharset())) + .getAsJsonArray(); + List> anomalies = new ArrayList<>(windows.size()); + for (int i = 0; i < windows.size(); i++) { + JsonArray window = windows.get(i).getAsJsonArray(); + Instant begin = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(window.get(0).getAsString())); + Instant end = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(window.get(1).getAsString())); + anomalies.add(new SimpleEntry<>(begin, end)); + } + return anomalies; + } + + /** + * Simulate starting detector without waiting for job scheduler to run. Our build process is already very slow (takes 10 mins+) + * to finish integration tests. This method triggers run API to simulate job scheduler execution in a fast-paced way. 
+ * @param detectorId Detector Id + * @param data Data in Json format + * @param trainTestSplit Training data size + * @param shingleSize Shingle size + * @param intervalMinutes Detector Interval + * @param client OpenSearch Client + * @throws Exception when failing to query/indexing from/to OpenSearch + */ + private void simulateSingleStreamStartDetector( + String detectorId, + List data, + int trainTestSplit, + int shingleSize, + int intervalMinutes, + RestClient client + ) throws Exception { + + Instant trainTime = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(data.get(trainTestSplit - 1).get("timestamp").getAsString())); + + Instant begin = null; + Instant end = null; + for (int i = 0; i < shingleSize; i++) { + begin = trainTime.minus(intervalMinutes * (shingleSize - 1 - i), ChronoUnit.MINUTES); + end = begin.plus(intervalMinutes, ChronoUnit.MINUTES); + try { + getDetectionResult(detectorId, begin, end, client); + } catch (Exception e) {} + } + // It takes time to wait for model initialization + long startTime = System.currentTimeMillis(); + do { + try { + Thread.sleep(5_000); + getDetectionResult(detectorId, begin, end, client); + break; + } catch (Exception e) { + long duration = System.currentTimeMillis() - startTime; + // we wait at most 60 secs + if (duration > 60_000) { + throw new RuntimeException(e); + } + } + } while (true); + } + + private void bulkIndexTestData(List data, String datasetName, int trainTestSplit, RestClient client) throws Exception { + StringBuilder bulkRequestBuilder = new StringBuilder(); + for (int i = trainTestSplit; i < data.size(); i++) { + bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + datasetName + "\", \"_id\" : \"" + i + "\" } }\n"); + bulkRequestBuilder.append(data.get(i).toString()).append("\n"); + } + TestHelpers + .makeRequest( + client, + "POST", + "_bulk?refresh=true", + null, + toHttpEntity(bulkRequestBuilder.toString()), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + Thread.sleep(1_000); + waitAllSyncheticDataIngested(data.size(), datasetName, client); + } + + private int isAnomaly(Instant time, List> labels) { + for (int i = 0; i < labels.size(); i++) { + Entry window = labels.get(i); + if (time.compareTo(window.getKey()) >= 0 && time.compareTo(window.getValue()) <= 0) { + return i; + } + } + return -1; + } +} diff --git a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java index 0051154bc..b5ce70d05 100644 --- a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java +++ b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java @@ -56,10 +56,10 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionListener; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.util.ArrayEqMatcher; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.dataprocessor.Imputer; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; @@ -128,7 +128,7 @@ public void setup() { ExecutorService executorService = mock(ExecutorService.class); - when(threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + when(threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); 
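// The stub below makes the mocked ExecutorService run each submitted Runnable inline
// on the calling thread, so work FeatureManager hands to the AD thread pool completes
// synchronously within the test and listener interactions can be verified immediately.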
doAnswer(invocation -> { Runnable runnable = invocation.getArgument(0); runnable.run(); @@ -150,7 +150,7 @@ public void setup() { maxPreviewSamples, featureBufferTtl, threadPool, - AnomalyDetectorPlugin.AD_THREAD_POOL_NAME + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ) ); } @@ -226,7 +226,7 @@ public void getColdStartData_returnExpectedToListener( maxPreviewSamples, featureBufferTtl, threadPool, - AnomalyDetectorPlugin.AD_THREAD_POOL_NAME + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ) ); featureManager.getColdStartData(detector, listener); diff --git a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java-e b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java-e new file mode 100644 index 000000000..b5ce70d05 --- /dev/null +++ b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java-e @@ -0,0 +1,1093 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.feature; + +import static java.util.Arrays.asList; +import static java.util.Optional.empty; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Matchers.argThat; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE; + +import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.AbstractMap.SimpleEntry; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Deque; +import java.util.List; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; + +import junitparams.JUnitParamsRunner; +import junitparams.Parameters; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.util.ArrayEqMatcher; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.dataprocessor.Imputer; +import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; + +@RunWith(JUnitParamsRunner.class) +@SuppressWarnings("unchecked") +public class FeatureManagerTests { + + // configuration + private int maxTrainSamples; + private int maxSampleStride; + private 
int trainSampleTimeRangeInHours; + private int minTrainSamples; + private int shingleSize; + private double maxMissingPointsRate; + private int maxNeighborDistance; + private double previewSampleRate; + private int maxPreviewSamples; + private Duration featureBufferTtl; + private long intervalInMilliseconds; + + @Mock + private AnomalyDetector detector; + + @Mock + private SearchFeatureDao searchFeatureDao; + + @Mock + private Imputer imputer; + + @Mock + private Clock clock; + + @Mock + private ThreadPool threadPool; + + private FeatureManager featureManager; + + private String detectorId; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + + maxTrainSamples = 24; + maxSampleStride = 100; + trainSampleTimeRangeInHours = 1; + minTrainSamples = 4; + shingleSize = 3; + maxMissingPointsRate = 0.67; + maxNeighborDistance = 2; + previewSampleRate = 0.5; + maxPreviewSamples = 2; + featureBufferTtl = Duration.ofMillis(1_000L); + + detectorId = "id"; + when(detector.getId()).thenReturn(detectorId); + when(detector.getShingleSize()).thenReturn(shingleSize); + IntervalTimeConfiguration detectorIntervalTimeConfig = new IntervalTimeConfiguration(1, ChronoUnit.MINUTES); + intervalInMilliseconds = detectorIntervalTimeConfig.toDuration().toMillis(); + when(detector.getIntervalInMilliseconds()).thenReturn(intervalInMilliseconds); + + Imputer imputer = new LinearUniformImputer(false); + + ExecutorService executorService = mock(ExecutorService.class); + + when(threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + + this.featureManager = spy( + new FeatureManager( + searchFeatureDao, + imputer, + clock, + maxTrainSamples, + maxSampleStride, + trainSampleTimeRangeInHours, + minTrainSamples, + maxMissingPointsRate, + maxNeighborDistance, + previewSampleRate, + maxPreviewSamples, + featureBufferTtl, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME + ) + ); + } + + private Object[] getColdStartDataTestData() { + double[][] samples = new double[][] { { 1.0 } }; + return new Object[] { + new Object[] { 1L, new SimpleEntry<>(samples, 1), 1, samples }, + new Object[] { 1L, null, 1, null }, + new Object[] { null, new SimpleEntry<>(samples, 1), 1, null }, + new Object[] { null, null, 1, null }, }; + } + + private Object[] getTrainDataTestData() { + List> ranges = asList( + entry(0L, 900_000L), + entry(900_000L, 1_800_000L), + entry(1_800_000L, 2_700_000L), + entry(2_700_000L, 3_600_000L) + ); + return new Object[] { + new Object[] { 3_600_000L, ranges, asList(ar(1), ar(2), ar(3), ar(4)), new double[][] { { 1, 2, 3, 4 } } }, + new Object[] { 3_600_000L, ranges, asList(ar(), ar(2), ar(3), ar(4)), new double[][] { { 2, 2, 3, 4 } } }, + new Object[] { 3_600_000L, ranges, asList(ar(1), ar(), ar(3), ar(4)), new double[][] { { 1, 3, 3, 4 } } }, + new Object[] { 3_600_000L, ranges, asList(ar(1), ar(2), ar(), ar(4)), new double[][] { { 1, 2, 4, 4 } } }, + new Object[] { 3_600_000L, ranges, asList(ar(1), ar(), ar(), ar(4)), new double[][] { { 1, 1, 4, 4 } } }, + new Object[] { 3_600_000L, ranges, asList(ar(), ar(2), ar(), ar(4)), new double[][] { { 2, 2, 4, 4 } } }, + new Object[] { 3_600_000L, ranges, asList(ar(), ar(), ar(3), ar(4)), null }, + new Object[] { 3_600_000L, ranges, asList(ar(1), empty(), empty(), empty()), null }, + new Object[] { 3_600_000L, ranges, asList(empty(), 
empty(), empty(), ar(4)), null }, + new Object[] { 3_600_000L, ranges, asList(empty(), empty(), empty(), empty()), null }, + new Object[] { null, null, null, null } }; + } + + @Test + @SuppressWarnings("unchecked") + @Parameters(method = "getTrainDataTestData") + public void getColdStartData_returnExpectedToListener( + Long latestTime, + List> sampleRanges, + List> samples, + double[][] expected + ) throws Exception { + long detectionInterval = (new IntervalTimeConfiguration(15, ChronoUnit.MINUTES)).toDuration().toMillis(); + when(detector.getIntervalInMilliseconds()).thenReturn(detectionInterval); + when(detector.getShingleSize()).thenReturn(4); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.ofNullable(latestTime)); + return null; + }).when(searchFeatureDao).getLatestDataTime(eq(detector), any(ActionListener.class)); + if (latestTime != null) { + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(samples); + return null; + }).when(searchFeatureDao).getFeatureSamplesForPeriods(eq(detector), eq(sampleRanges), any(ActionListener.class)); + } + + ActionListener> listener = mock(ActionListener.class); + featureManager = spy( + new FeatureManager( + searchFeatureDao, + imputer, + clock, + maxTrainSamples, + maxSampleStride, + trainSampleTimeRangeInHours, + minTrainSamples, + 0.5, /*maxMissingPointsRate*/ + 1, /*maxNeighborDistance*/ + previewSampleRate, + maxPreviewSamples, + featureBufferTtl, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME + ) + ); + featureManager.getColdStartData(detector, listener); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(Optional.class); + verify(listener).onResponse(captor.capture()); + Optional result = captor.getValue(); + assertTrue(Arrays.deepEquals(expected, result.orElse(null))); + } + + @Test + @SuppressWarnings("unchecked") + public void getColdStartData_throwToListener_whenSearchFail() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException()); + return null; + }).when(searchFeatureDao).getLatestDataTime(eq(detector), any(ActionListener.class)); + + ActionListener> listener = mock(ActionListener.class); + featureManager.getColdStartData(detector, listener); + + verify(listener).onFailure(any(Exception.class)); + } + + @Test + @SuppressWarnings("unchecked") + public void getColdStartData_throwToListener_onQueryCreationError() throws Exception { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.ofNullable(0L)); + return null; + }).when(searchFeatureDao).getLatestDataTime(eq(detector), any(ActionListener.class)); + doThrow(IOException.class).when(searchFeatureDao).getFeatureSamplesForPeriods(eq(detector), any(), any(ActionListener.class)); + + ActionListener> listener = mock(ActionListener.class); + featureManager.getColdStartData(detector, listener); + + verify(listener).onFailure(any(EndRunException.class)); + } + + private Object[] batchShingleData() { + return new Object[] { + new Object[] { new double[][] { { 1.0 } }, 1, new double[][] { { 1.0 } } }, + new Object[] { new double[][] { { 1.0, 2.0 } }, 1, new double[][] { { 1.0, 2.0 } } }, + new Object[] { new double[][] { { 1.0 }, { 2, 0 }, { 3.0 } }, 1, new double[][] { { 1.0 }, { 2.0 }, { 3.0 } } }, + new Object[] { new double[][] { { 1.0 }, { 2, 0 }, { 3.0 } }, 2, new double[][] { { 1.0, 2.0 }, { 2.0, 3.0 } } }, + new Object[] { new 
double[][] { { 1.0 }, { 2, 0 }, { 3.0 } }, 3, new double[][] { { 1.0, 2.0, 3.0 } } }, + new Object[] { new double[][] { { 1.0, 2.0 }, { 3.0, 4.0 } }, 1, new double[][] { { 1.0, 2.0 }, { 3.0, 4.0 } } }, + new Object[] { new double[][] { { 1.0, 2.0 }, { 3.0, 4.0 } }, 2, new double[][] { { 1.0, 2.0, 3.0, 4.0 } } }, + new Object[] { + new double[][] { { 1.0, 2.0 }, { 3.0, 4.0 }, { 5.0, 6.0 } }, + 3, + new double[][] { { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 } } } }; + }; + + @Test + @Parameters(method = "batchShingleData") + public void batchShingle_returnExpected(double[][] points, int shingleSize, double[][] expected) { + assertTrue(Arrays.deepEquals(expected, featureManager.batchShingle(points, shingleSize))); + } + + private Object[] batchShingleIllegalArgumentData() { + return new Object[] { + new Object[] { new double[][] { { 1.0 } }, 0 }, + new Object[] { new double[][] { { 1.0 } }, 2 }, + new Object[] { new double[][] { { 1.0, 2.0 } }, 0 }, + new Object[] { new double[][] { { 1.0, 2.0 } }, 2 }, + new Object[] { new double[][] { { 1.0 }, { 2.0 } }, 0 }, + new Object[] { new double[][] { { 1.0 }, { 2.0 } }, 3 }, + new Object[] { new double[][] {}, 0 }, + new Object[] { new double[][] {}, 1 }, + new Object[] { new double[][] { {}, {} }, 0 }, + new Object[] { new double[][] { {}, {} }, 1 }, + new Object[] { new double[][] { {}, {} }, 2 }, + new Object[] { new double[][] { {}, {} }, 3 }, }; + }; + + @Test(expected = IllegalArgumentException.class) + @Parameters(method = "batchShingleIllegalArgumentData") + public void batchShingle_throwExpected_forInvalidInput(double[][] points, int shingleSize) { + featureManager.batchShingle(points, shingleSize); + } + + @Test + public void clear_deleteFeatures() throws IOException { + long start = shingleSize * intervalInMilliseconds; + long end = (shingleSize + 1) * intervalInMilliseconds; + + AtomicBoolean firstQuery = new AtomicBoolean(true); + + doAnswer(invocation -> { + ActionListener>> daoListener = invocation.getArgument(2); + if (firstQuery.get()) { + firstQuery.set(false); + daoListener + .onResponse(asList(Optional.of(new double[] { 3 }), Optional.of(new double[] { 2 }), Optional.of(new double[] { 1 }))); + } else { + daoListener.onResponse(asList(Optional.ofNullable(null), Optional.ofNullable(null), Optional.of(new double[] { 1 }))); + } + return null; + }).when(searchFeatureDao).getFeatureSamplesForPeriods(eq(detector), any(List.class), any(ActionListener.class)); + featureManager.getCurrentFeatures(detector, start, end, mock(ActionListener.class)); + + SinglePointFeatures beforeMaintenance = getCurrentFeatures(detector, start, end); + assertTrue(beforeMaintenance.getUnprocessedFeatures().isPresent()); + assertTrue(beforeMaintenance.getProcessedFeatures().isPresent()); + + featureManager.clear(detector.getId()); + + SinglePointFeatures afterMaintenance = getCurrentFeatures(detector, start, end); + assertTrue(afterMaintenance.getUnprocessedFeatures().isPresent()); + assertFalse(afterMaintenance.getProcessedFeatures().isPresent()); + } + + private SinglePointFeatures getCurrentFeatures(AnomalyDetector detector, long start, long end) throws IOException { + ActionListener listener = mock(ActionListener.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(SinglePointFeatures.class); + featureManager.getCurrentFeatures(detector, start, end, listener); + verify(listener).onResponse(captor.capture()); + return captor.getValue(); + } + + @Test + public void maintenance_removeStaleData() throws IOException { + long start = shingleSize * 
intervalInMilliseconds; + long end = (shingleSize + 1) * intervalInMilliseconds; + + AtomicBoolean firstQuery = new AtomicBoolean(true); + + doAnswer(invocation -> { + ActionListener>> daoListener = invocation.getArgument(2); + if (firstQuery.get()) { + firstQuery.set(false); + daoListener + .onResponse(asList(Optional.of(new double[] { 3 }), Optional.of(new double[] { 2 }), Optional.of(new double[] { 1 }))); + } else { + daoListener.onResponse(asList(Optional.ofNullable(null), Optional.ofNullable(null), Optional.of(new double[] { 1 }))); + } + return null; + }).when(searchFeatureDao).getFeatureSamplesForPeriods(eq(detector), any(List.class), any(ActionListener.class)); + featureManager.getCurrentFeatures(detector, start, end, mock(ActionListener.class)); + + SinglePointFeatures beforeMaintenance = getCurrentFeatures(detector, start, end); + assertTrue(beforeMaintenance.getUnprocessedFeatures().isPresent()); + assertTrue(beforeMaintenance.getProcessedFeatures().isPresent()); + when(clock.instant()).thenReturn(Instant.ofEpochMilli(end + 1).plus(featureBufferTtl)); + + featureManager.maintenance(); + + SinglePointFeatures afterMaintenance = getCurrentFeatures(detector, start, end); + assertTrue(afterMaintenance.getUnprocessedFeatures().isPresent()); + assertFalse(afterMaintenance.getProcessedFeatures().isPresent()); + } + + @Test + public void maintenance_keepRecentData() throws IOException { + long start = shingleSize * intervalInMilliseconds; + long end = (shingleSize + 1) * intervalInMilliseconds; + + AtomicBoolean firstQuery = new AtomicBoolean(true); + + doAnswer(invocation -> { + ActionListener>> daoListener = invocation.getArgument(2); + if (firstQuery.get()) { + firstQuery.set(false); + daoListener + .onResponse(asList(Optional.of(new double[] { 3 }), Optional.of(new double[] { 2 }), Optional.of(new double[] { 1 }))); + } else { + daoListener.onResponse(asList(Optional.ofNullable(null), Optional.ofNullable(null), Optional.of(new double[] { 1 }))); + } + return null; + }).when(searchFeatureDao).getFeatureSamplesForPeriods(eq(detector), any(List.class), any(ActionListener.class)); + featureManager.getCurrentFeatures(detector, start, end, mock(ActionListener.class)); + + SinglePointFeatures beforeMaintenance = getCurrentFeatures(detector, start, end); + assertTrue(beforeMaintenance.getUnprocessedFeatures().isPresent()); + assertTrue(beforeMaintenance.getProcessedFeatures().isPresent()); + when(clock.instant()).thenReturn(Instant.ofEpochMilli(end)); + + featureManager.maintenance(); + + SinglePointFeatures afterMaintenance = getCurrentFeatures(detector, start, end); + assertTrue(afterMaintenance.getUnprocessedFeatures().isPresent()); + assertTrue(afterMaintenance.getProcessedFeatures().isPresent()); + } + + @Test + public void maintenance_doNotThrowException() { + when(clock.instant()).thenThrow(new RuntimeException()); + + featureManager.maintenance(); + } + + @SuppressWarnings("unchecked") + private void getPreviewFeaturesTemplate(List> samplesResults, boolean querySuccess, boolean previewSuccess) + throws IOException { + long start = 0L; + long end = 240_000L; + long detectionInterval = (new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)).toDuration().toMillis(); + when(detector.getIntervalInMilliseconds()).thenReturn(detectionInterval); + + List> sampleRanges = Arrays.asList(new SimpleEntry<>(0L, 60_000L), new SimpleEntry<>(120_000L, 180_000L)); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + + ActionListener>> listener = null; + + if (args[2] instanceof 
ActionListener) { + listener = (ActionListener>>) args[2]; + } + + if (querySuccess) { + listener.onResponse(samplesResults); + } else { + listener.onFailure(new RuntimeException()); + } + + return null; + }).when(searchFeatureDao).getFeatureSamplesForPeriods(eq(detector), eq(sampleRanges), any()); + + when(imputer.impute(argThat(new ArrayEqMatcher<>(new double[][] { { 1, 3 } })), eq(3))).thenReturn(new double[][] { { 1, 2, 3 } }); + when(imputer.impute(argThat(new ArrayEqMatcher<>(new double[][] { { 0, 120000 } })), eq(3))) + .thenReturn(new double[][] { { 0, 60000, 120000 } }); + when(imputer.impute(argThat(new ArrayEqMatcher<>(new double[][] { { 60000, 180000 } })), eq(3))) + .thenReturn(new double[][] { { 60000, 120000, 180000 } }); + + ActionListener listener = mock(ActionListener.class); + featureManager.getPreviewFeatures(detector, start, end, listener); + + if (previewSuccess) { + Features expected = new Features( + asList(new SimpleEntry<>(120_000L, 180_000L)), + new double[][] { { 3 } }, + new double[][] { { 1, 2, 3 } } + ); + verify(listener).onResponse(expected); + } else { + verify(listener).onFailure(any(Exception.class)); + } + + } + + @Test + public void getPreviewFeatures_returnExpectedToListener() throws IOException { + getPreviewFeaturesTemplate(asList(Optional.of(new double[] { 1 }), Optional.of(new double[] { 3 })), true, true); + } + + @Test + public void getPreviewFeatures_returnExceptionToListener_whenNoDataToPreview() throws IOException { + getPreviewFeaturesTemplate(asList(), true, false); + } + + @Test + public void getPreviewFeatures_returnExceptionToListener_whenQueryFail() throws IOException { + getPreviewFeaturesTemplate(asList(Optional.of(new double[] { 1 }), Optional.of(new double[] { 3 })), false, false); + } + + @Test + public void getPreviewFeatureForEntity() throws IOException { + long start = 0L; + long end = 240_000L; + Entity entity = Entity.createSingleAttributeEntity("fieldName", "value"); + + List> coldStartSamples = new ArrayList<>(); + coldStartSamples.add(Optional.of(new double[] { 10.0 })); + coldStartSamples.add(Optional.of(new double[] { 30.0 })); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(4); + listener.onResponse(coldStartSamples); + return null; + }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); + + ActionListener listener = mock(ActionListener.class); + + featureManager.getPreviewFeaturesForEntity(detector, entity, start, end, listener); + + Features expected = new Features( + asList(new SimpleEntry<>(120_000L, 180_000L)), + new double[][] { { 30 } }, + new double[][] { { 10, 20, 30 } } + ); + verify(listener).onResponse(expected); + } + + @Test + public void getPreviewFeatureForEntity_noDataToPreview() throws IOException { + long start = 0L; + long end = 240_000L; + Entity entity = Entity.createSingleAttributeEntity("fieldName", "value"); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(4); + listener.onResponse(new ArrayList<>()); + return null; + }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); + + ActionListener listener = mock(ActionListener.class); + + featureManager.getPreviewFeaturesForEntity(detector, entity, start, end, listener); + + verify(listener).onFailure(any(IllegalArgumentException.class)); + } + + @Test + public void getPreviewEntities() { + long start = 0L; + long end = 240_000L; + + Entity entity1 = Entity.createSingleAttributeEntity("fieldName", 
"value1"); + Entity entity2 = Entity.createSingleAttributeEntity("fieldName", "value2"); + List entities = asList(entity1, entity2); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onResponse(entities); + return null; + }).when(searchFeatureDao).getHighestCountEntities(any(), anyLong(), anyLong(), any()); + + ActionListener> listener = mock(ActionListener.class); + + featureManager.getPreviewEntities(detector, start, end, listener); + + verify(listener).onResponse(entities); + } + + private void setupSearchFeatureDaoForGetCurrentFeatures( + List> preQueryResponse, + Optional>> testQueryResponse + ) throws IOException { + AtomicBoolean isPreQuery = new AtomicBoolean(true); + + doAnswer(invocation -> { + ActionListener>> daoListener = invocation.getArgument(2); + if (isPreQuery.get()) { + isPreQuery.set(false); + daoListener.onResponse(preQueryResponse); + } else { + if (testQueryResponse.isPresent()) { + daoListener.onResponse(testQueryResponse.get()); + } else { + daoListener.onFailure(new IOException()); + } + } + return null; + }).when(searchFeatureDao).getFeatureSamplesForPeriods(eq(detector), any(List.class), any(ActionListener.class)); + } + + private Object[] getCurrentFeaturesTestData_whenAfterQueryResultsFormFullShingle() { + return new Object[] { + new Object[] { + asList(Optional.empty(), Optional.empty(), Optional.empty()), + 3, + Optional.of(asList(Optional.of(new double[] { 1 }), Optional.of(new double[] { 2 }), Optional.of(new double[] { 3 }))), + new double[] { 1, 2, 3 } }, + new Object[] { + asList(Optional.empty(), Optional.of(new double[] { 1 }), Optional.of(new double[] { 5 })), + 1, + Optional.of(asList(Optional.of(new double[] { 3 }))), + new double[] { 1, 5, 3 } }, + new Object[] { + asList(Optional.empty(), Optional.empty(), Optional.of(new double[] { 1, 2 })), + 2, + Optional.of(asList(Optional.of(new double[] { 3, 4 }), Optional.of(new double[] { 5, 6 }))), + new double[] { 1, 2, 3, 4, 5, 6 } }, }; + } + + @Test + @Parameters(method = "getCurrentFeaturesTestData_whenAfterQueryResultsFormFullShingle") + public void getCurrentFeatures_returnExpectedProcessedFeatures_whenAfterQueryResultsFormFullShingle( + List> preQueryResponse, + long intervalOffsetFromPreviousQuery, + Optional>> testQueryResponse, + double[] expectedProcessedFeatures + ) throws IOException { + int expectedNumQueriesToSearchFeatureDao = 2; + long previousStartTime = shingleSize * intervalInMilliseconds; + long previousEndTime = previousStartTime + intervalInMilliseconds; + long testStartTime = previousStartTime + intervalOffsetFromPreviousQuery * intervalInMilliseconds; + long testEndTime = testStartTime + intervalInMilliseconds; + + // Set up + setupSearchFeatureDaoForGetCurrentFeatures(preQueryResponse, testQueryResponse); + featureManager.getCurrentFeatures(detector, previousStartTime, previousEndTime, mock(ActionListener.class)); + + // Start test + SinglePointFeatures listenerResponse = getCurrentFeatures(detector, testStartTime, testEndTime); + verify(searchFeatureDao, times(expectedNumQueriesToSearchFeatureDao)) + .getFeatureSamplesForPeriods(eq(detector), any(List.class), any(ActionListener.class)); + assertTrue(listenerResponse.getUnprocessedFeatures().isPresent()); + assertTrue(listenerResponse.getProcessedFeatures().isPresent()); + + double[] actualProcessedFeatures = listenerResponse.getProcessedFeatures().get(); + for (int i = 0; i < expectedProcessedFeatures.length; i++) { + assertEquals(expectedProcessedFeatures[i], 
actualProcessedFeatures[i], 0); + } + } + + private Object[] getCurrentFeaturesTestData_whenNoQueryNeededToFormFullShingle() { + return new Object[] { + new Object[] { + asList(Optional.of(new double[] { 1 }), Optional.of(new double[] { 2 }), Optional.of(new double[] { 3 })), + new double[] { 1, 2, 3 } }, + new Object[] { + asList(Optional.of(new double[] { 1, 2 }), Optional.of(new double[] { 3, 4 }), Optional.of(new double[] { 5, 6 })), + new double[] { 1, 2, 3, 4, 5, 6 } } }; + } + + @Test + @Parameters(method = "getCurrentFeaturesTestData_whenNoQueryNeededToFormFullShingle") + public void getCurrentFeatures_returnExpectedProcessedFeatures_whenNoQueryNeededToFormFullShingle( + List> preQueryResponse, + double[] expectedProcessedFeatures + ) throws IOException { + int expectedNumQueriesToSearchFeatureDao = 1; + long start = shingleSize * intervalInMilliseconds; + long end = start + intervalInMilliseconds; + + // Set up + setupSearchFeatureDaoForGetCurrentFeatures(preQueryResponse, Optional.empty()); + featureManager.getCurrentFeatures(detector, start, end, mock(ActionListener.class)); + + // Start test + SinglePointFeatures listenerResponse = getCurrentFeatures(detector, start, end); + verify(searchFeatureDao, times(expectedNumQueriesToSearchFeatureDao)) + .getFeatureSamplesForPeriods(eq(detector), any(List.class), any(ActionListener.class)); + assertTrue(listenerResponse.getUnprocessedFeatures().isPresent()); + assertTrue(listenerResponse.getProcessedFeatures().isPresent()); + + double[] actualProcessedFeatures = listenerResponse.getProcessedFeatures().get(); + for (int i = 0; i < expectedProcessedFeatures.length; i++) { + assertEquals(expectedProcessedFeatures[i], actualProcessedFeatures[i], 0); + } + } + + private Object[] getCurrentFeaturesTestData_whenAfterQueryResultsAllowImputedShingle() { + return new Object[] { + new Object[] { + asList(Optional.empty(), Optional.empty(), Optional.empty()), + 3, + Optional.of(asList(Optional.of(new double[] { 1 }), Optional.empty(), Optional.of(new double[] { 3 }))), + new double[] { 1, 3, 3 } }, + new Object[] { + asList(Optional.of(new double[] { 1 }), Optional.empty(), Optional.of(new double[] { 5 })), + 1, + Optional.of(asList(Optional.of(new double[] { 3 }))), + new double[] { 5, 5, 3 } }, + new Object[] { + asList(Optional.empty(), Optional.of(new double[] { 1 }), Optional.empty()), + 1, + Optional.of(asList(Optional.of(new double[] { 2 }))), + new double[] { 1, 2, 2 } }, + new Object[] { + asList(Optional.empty(), Optional.empty(), Optional.of(new double[] { 1 })), + 2, + Optional.of(asList(Optional.empty(), Optional.of(new double[] { 2 }))), + new double[] { 1, 2, 2 } }, + new Object[] { + asList(Optional.of(new double[] { 5, 6 }), Optional.empty(), Optional.empty()), + 2, + Optional.of(asList(Optional.of(new double[] { 3, 4 }), Optional.of(new double[] { 1, 2 }))), + new double[] { 3, 4, 3, 4, 1, 2 } }, }; + } + + @Test + @Parameters(method = "getCurrentFeaturesTestData_whenAfterQueryResultsAllowImputedShingle") + public void getCurrentFeatures_returnExpectedProcessedFeatures_whenAfterQueryResultsAllowImputedShingle( + List> preQueryResponse, + long intervalOffsetFromPreviousQuery, + Optional>> testQueryResponse, + double[] expectedProcessedFeatures + ) throws IOException { + int expectedNumQueriesToSearchFeatureDao = 2; + long previousStartTime = (shingleSize + 1) * intervalInMilliseconds; + long previousEndTime = previousStartTime + intervalInMilliseconds; + long testStartTime = previousStartTime + (intervalOffsetFromPreviousQuery * 
intervalInMilliseconds); + long testEndTime = testStartTime + intervalInMilliseconds; + + // Set up + setupSearchFeatureDaoForGetCurrentFeatures(preQueryResponse, testQueryResponse); + featureManager.getCurrentFeatures(detector, previousStartTime, previousEndTime, mock(ActionListener.class)); + + // Start test + SinglePointFeatures listenerResponse = getCurrentFeatures(detector, testStartTime, testEndTime); + verify(searchFeatureDao, times(expectedNumQueriesToSearchFeatureDao)) + .getFeatureSamplesForPeriods(eq(detector), any(List.class), any(ActionListener.class)); + assertTrue(listenerResponse.getUnprocessedFeatures().isPresent()); + assertTrue(listenerResponse.getProcessedFeatures().isPresent()); + + double[] actualProcessedFeatures = listenerResponse.getProcessedFeatures().get(); + for (int i = 0; i < expectedProcessedFeatures.length; i++) { + assertEquals(expectedProcessedFeatures[i], actualProcessedFeatures[i], 0); + } + } + + private Object[] getCurrentFeaturesTestData_whenMissingCurrentDataPoint() { + return new Object[] { + new Object[] { + asList(Optional.empty(), Optional.empty(), Optional.empty()), + 3, + Optional.of(asList(Optional.of(new double[] { 1 }), Optional.of(new double[] { 3 }), Optional.empty())), }, + new Object[] { + asList(Optional.of(new double[] { 1 }), Optional.of(new double[] { 1 }), Optional.empty()), + 1, + Optional.of(asList(Optional.empty())), }, + new Object[] { + asList(Optional.empty(), Optional.empty(), Optional.of(new double[] { 1, 2, 3 })), + 2, + Optional.of(asList(Optional.of(new double[] { 4, 5, 6 }), Optional.empty())), } }; + } + + @Test + @Parameters(method = "getCurrentFeaturesTestData_whenMissingCurrentDataPoint") + public void getCurrentFeatures_returnNoProcessedOrUnprocessedFeatures_whenMissingCurrentDataPoint( + List> preQueryResponse, + long intervalOffsetFromPreviousQuery, + Optional>> testQueryResponse + ) throws IOException { + int expectedNumQueriesToSearchFeatureDao = 2; + long previousStartTime = shingleSize * intervalInMilliseconds; + long previousEndTime = previousStartTime + intervalInMilliseconds; + long testStartTime = previousStartTime + intervalOffsetFromPreviousQuery * intervalInMilliseconds; + long testEndTime = testStartTime + intervalInMilliseconds; + + // Set up + setupSearchFeatureDaoForGetCurrentFeatures(preQueryResponse, testQueryResponse); + featureManager.getCurrentFeatures(detector, previousStartTime, previousEndTime, mock(ActionListener.class)); + + // Start test + SinglePointFeatures listenerResponse = getCurrentFeatures(detector, testStartTime, testEndTime); + verify(searchFeatureDao, times(expectedNumQueriesToSearchFeatureDao)) + .getFeatureSamplesForPeriods(eq(detector), any(List.class), any(ActionListener.class)); + assertFalse(listenerResponse.getUnprocessedFeatures().isPresent()); + assertFalse(listenerResponse.getProcessedFeatures().isPresent()); + } + + private Object[] getCurrentFeaturesTestData_whenAfterQueryResultsCannotBeShingled() { + return new Object[] { + new Object[] { + asList(Optional.of(new double[] { 1 }), Optional.of(new double[] { 2 }), Optional.of(new double[] { 3 })), + 3, + Optional.of(asList(Optional.empty(), Optional.empty(), Optional.of(new double[] { 4 }))), }, + new Object[] { + asList(Optional.of(new double[] { 1, 2 }), Optional.empty(), Optional.empty()), + 1, + Optional.of(asList(Optional.of(new double[] { 3, 4 }))), } }; + } + + @Test + @Parameters(method = "getCurrentFeaturesTestData_whenAfterQueryResultsCannotBeShingled") + public void 
getCurrentFeatures_returnNoProcessedFeatures_whenAfterQueryResultsCannotBeShingled( + List> preQueryResponse, + long intervalOffsetFromPreviousQuery, + Optional>> testQueryResponse + ) throws IOException { + int expectedNumQueriesToSearchFeatureDao = 2; + long previousStartTime = shingleSize * intervalInMilliseconds; + long previousEndTime = previousStartTime + intervalInMilliseconds; + long testStartTime = previousStartTime + intervalOffsetFromPreviousQuery * intervalInMilliseconds; + long testEndTime = testStartTime + intervalInMilliseconds; + + // Set up + setupSearchFeatureDaoForGetCurrentFeatures(preQueryResponse, testQueryResponse); + featureManager.getCurrentFeatures(detector, previousStartTime, previousEndTime, mock(ActionListener.class)); + + // Start test + SinglePointFeatures listenerResponse = getCurrentFeatures(detector, testStartTime, testEndTime); + verify(searchFeatureDao, times(expectedNumQueriesToSearchFeatureDao)) + .getFeatureSamplesForPeriods(eq(detector), any(List.class), any(ActionListener.class)); + assertTrue(listenerResponse.getUnprocessedFeatures().isPresent()); + assertFalse(listenerResponse.getProcessedFeatures().isPresent()); + } + + private Object[] getCurrentFeaturesTestData_whenQueryThrowsIOException() { + return new Object[] { + new Object[] { asList(Optional.empty(), Optional.empty(), Optional.empty()), 3 }, + new Object[] { asList(Optional.empty(), Optional.of(new double[] { 1, 2 }), Optional.of(new double[] { 3, 4 })), 1 } }; + } + + @Test + @Parameters(method = "getCurrentFeaturesTestData_whenQueryThrowsIOException") + public void getCurrentFeatures_returnExceptionToListener_whenQueryThrowsIOException( + List> preQueryResponse, + long intervalOffsetFromPreviousQuery + ) throws IOException { + int expectedNumQueriesToSearchFeatureDao = 2; + long previousStartTime = shingleSize * intervalInMilliseconds; + long previousEndTime = previousStartTime + intervalInMilliseconds; + long testStartTime = previousStartTime + intervalOffsetFromPreviousQuery * intervalInMilliseconds; + long testEndTime = testStartTime + intervalInMilliseconds; + + // Set up + setupSearchFeatureDaoForGetCurrentFeatures(preQueryResponse, Optional.empty()); + featureManager.getCurrentFeatures(detector, previousStartTime, previousEndTime, mock(ActionListener.class)); + + // Start test + ActionListener listener = mock(ActionListener.class); + featureManager.getCurrentFeatures(detector, testStartTime, testEndTime, listener); + verify(searchFeatureDao, times(expectedNumQueriesToSearchFeatureDao)) + .getFeatureSamplesForPeriods(eq(detector), any(List.class), any(ActionListener.class)); + verify(listener).onFailure(any(IOException.class)); + } + + private Object[] getCurrentFeaturesTestData_cacheMissingData() { + return new Object[] { + new Object[] { + asList(Optional.empty(), Optional.empty(), Optional.empty()), + Optional.of(asList(Optional.of(new double[] { 1 }))), + Optional.empty() }, + new Object[] { + asList(Optional.of(new double[] { 1, 2 }), Optional.empty(), Optional.of(new double[] { 3, 4 })), + Optional.of(asList(Optional.of(new double[] { 5, 6 }))), + Optional.of(new double[] { 3, 4, 3, 4, 5, 6 }) } }; + } + + @Test + @Parameters(method = "getCurrentFeaturesTestData_cacheMissingData") + public void getCurrentFeatures_returnExpectedFeatures_cacheMissingData( + List> firstQueryResponseToBeCached, + Optional>> secondQueryResponse, + Optional expectedProcessedFeaturesOptional + ) throws IOException { + long firstStartTime = shingleSize * intervalInMilliseconds; + long firstEndTime = 
firstStartTime + intervalInMilliseconds; + long secondStartTime = firstEndTime; + long secondEndTime = secondStartTime + intervalInMilliseconds; + + setupSearchFeatureDaoForGetCurrentFeatures(firstQueryResponseToBeCached, secondQueryResponse); + + // first call to cache missing points + featureManager.getCurrentFeatures(detector, firstStartTime, firstEndTime, mock(ActionListener.class)); + verify(searchFeatureDao, times(1)) + .getFeatureSamplesForPeriods(eq(detector), argThat(list -> list.size() == shingleSize), any(ActionListener.class)); + + // second call should only fetch current point even if previous points missing + SinglePointFeatures listenerResponse = getCurrentFeatures(detector, secondStartTime, secondEndTime); + verify(searchFeatureDao, times(1)) + .getFeatureSamplesForPeriods(eq(detector), argThat(list -> list.size() == 1), any(ActionListener.class)); + + assertTrue(listenerResponse.getUnprocessedFeatures().isPresent()); + if (expectedProcessedFeaturesOptional.isPresent()) { + assertTrue(listenerResponse.getProcessedFeatures().isPresent()); + double[] expectedProcessedFeatures = expectedProcessedFeaturesOptional.get(); + double[] actualProcessedFeatures = listenerResponse.getProcessedFeatures().get(); + for (int i = 0; i < expectedProcessedFeatures.length; i++) { + assertEquals(expectedProcessedFeatures[i], actualProcessedFeatures[i], 0); + } + } else { + assertFalse(listenerResponse.getProcessedFeatures().isPresent()); + } + } + + private Object[] getCurrentFeaturesTestData_withTimeJitterUpToHalfInterval() { + return new Object[] { + new Object[] { + asList(Optional.empty(), Optional.empty(), Optional.of(new double[] { 1 })), + 2.1, + Optional.of(asList(Optional.of(new double[] { 2 }), Optional.of(new double[] { 3 }))), + new double[] { 1, 2, 3 } }, + new Object[] { + asList(Optional.of(new double[] { 1 }), Optional.empty(), Optional.of(new double[] { 5 })), + 0.8, + Optional.of(asList(Optional.of(new double[] { 3 }))), + new double[] { 5, 5, 3 } }, + new Object[] { + asList(Optional.empty(), Optional.empty(), Optional.of(new double[] { 1 })), + 1.49, + Optional.of(asList(Optional.of(new double[] { 2 }))), + new double[] { 1, 1, 2 } }, + new Object[] { + asList(Optional.empty(), Optional.empty(), Optional.of(new double[] { 1 })), + 1.51, + Optional.of(asList(Optional.empty(), Optional.of(new double[] { 2 }))), + new double[] { 1, 1, 2 } }, + new Object[] { + asList(Optional.empty(), Optional.empty(), Optional.of(new double[] { 1 })), + 2.49, + Optional.of(asList(Optional.empty(), Optional.of(new double[] { 2 }))), + new double[] { 1, 2, 2 } }, + new Object[] { + asList(Optional.of(new double[] { 1, 2 }), Optional.of(new double[] { 3, 4 }), Optional.of(new double[] { 5, 6 })), + 2.5, + Optional + .of( + asList( + Optional.of(new double[] { 7, 8 }), + Optional.of(new double[] { 9, 10 }), + Optional.of(new double[] { 11, 12 }) + ) + ), + new double[] { 7, 8, 9, 10, 11, 12 } }, }; + } + + @Test + @Parameters(method = "getCurrentFeaturesTestData_withTimeJitterUpToHalfInterval") + public void getCurrentFeatures_returnExpectedFeatures_withTimeJitterUpToHalfInterval( + List> preQueryResponse, + double intervalOffsetFromPreviousQuery, + Optional>> testQueryResponse, + double[] expectedProcessedFeatures + ) throws IOException { + int expectedNumQueriesToSearchFeatureDao = 2; + long previousStartTime = (shingleSize + 1) * intervalInMilliseconds; + long previousEndTime = previousStartTime + intervalInMilliseconds; + double millisecondsOffset = intervalOffsetFromPreviousQuery * 
intervalInMilliseconds; + long testStartTime = previousStartTime + (long) millisecondsOffset; + long testEndTime = testStartTime + intervalInMilliseconds; + + // Set up + setupSearchFeatureDaoForGetCurrentFeatures(preQueryResponse, testQueryResponse); + featureManager.getCurrentFeatures(detector, previousStartTime, previousEndTime, mock(ActionListener.class)); + + // Start test + SinglePointFeatures listenerResponse = getCurrentFeatures(detector, testStartTime, testEndTime); + verify(searchFeatureDao, times(expectedNumQueriesToSearchFeatureDao)) + .getFeatureSamplesForPeriods(eq(detector), any(List.class), any(ActionListener.class)); + assertTrue(listenerResponse.getUnprocessedFeatures().isPresent()); + assertTrue(listenerResponse.getProcessedFeatures().isPresent()); + + double[] actualProcessedFeatures = listenerResponse.getProcessedFeatures().get(); + for (int i = 0; i < expectedProcessedFeatures.length; i++) { + assertEquals(expectedProcessedFeatures[i], actualProcessedFeatures[i], 0); + } + } + + private Entry entry(K key, V value) { + return new SimpleEntry<>(key, value); + } + + private Optional ar(double... values) { + if (values.length == 0) { + return Optional.empty(); + } else { + return Optional.of(values); + } + } + + private Object[] getCurrentFeaturesTestData_setsShingleSizeFromDetectorConfig() { + return new Object[] { new Object[] { 1 }, new Object[] { 4 }, new Object[] { 8 }, new Object[] { 20 } }; + } + + @Test + @Parameters(method = "getCurrentFeaturesTestData_setsShingleSizeFromDetectorConfig") + public void getCurrentFeatures_setsShingleSizeFromDetectorConfig(int shingleSize) throws IOException { + when(detector.getShingleSize()).thenReturn(shingleSize); + + doAnswer(invocation -> { + List> ranges = invocation.getArgument(1); + assertEquals(ranges.size(), shingleSize); + + ActionListener>> daoListener = invocation.getArgument(2); + List> response = new ArrayList>(); + for (int i = 0; i < ranges.size(); i++) { + response.add(Optional.of(new double[] { i })); + } + daoListener.onResponse(response); + return null; + }).when(searchFeatureDao).getFeatureSamplesForPeriods(eq(detector), any(List.class), any(ActionListener.class)); + + SinglePointFeatures listenerResponse = getCurrentFeatures(detector, 0, intervalInMilliseconds); + assertTrue(listenerResponse.getProcessedFeatures().isPresent()); + assertEquals(listenerResponse.getProcessedFeatures().get().length, shingleSize); + assertEquals(featureManager.getShingleSize(detector.getId()), shingleSize); + } + + @Test + public void testGetShingledFeatureForHistoricalAnalysisFromEmptyShingleWithoutMissingData() { + long millisecondsPerMinute = 60000; + int shingleSize = 8; + when(detector.getShingleSize()).thenReturn(shingleSize); + + Deque>> shingle = new ArrayDeque<>(); + + long endTime = Instant.now().toEpochMilli(); + int i = 0; + for (; i < shingleSize - MAX_IMPUTATION_NEIGHBOR_DISTANCE; i++) { + double[] testData = new double[] { i }; + Optional dataPoint = Optional.of(testData); + SinglePointFeatures feature = featureManager.getShingledFeatureForHistoricalAnalysis(detector, shingle, dataPoint, endTime); + endTime += millisecondsPerMinute; + + assertTrue(Arrays.equals(testData, feature.getUnprocessedFeatures().get())); + assertFalse(feature.getProcessedFeatures().isPresent()); + } + + double[] testData = new double[] { i++ }; + Optional dataPoint = Optional.of(testData); + SinglePointFeatures feature = featureManager.getShingledFeatureForHistoricalAnalysis(detector, shingle, dataPoint, endTime); + 
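// At this point the shingle holds seven real points (values 0..6) for an eight-slot
// shingle, so the oldest slot is expected to be filled by nearest-neighbor imputation
// from the first point, hence the leading duplicate 0.0 in the expected array below;
// once the next point (7) arrives, the shingle is fully backed by real data.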
assertTrue(feature.getProcessedFeatures().isPresent()); + assertTrue(Arrays.equals(new double[] { 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 }, feature.getProcessedFeatures().get())); + + endTime += millisecondsPerMinute; + testData = new double[] { i++ }; + dataPoint = Optional.of(testData); + feature = featureManager.getShingledFeatureForHistoricalAnalysis(detector, shingle, dataPoint, endTime); + assertTrue(feature.getProcessedFeatures().isPresent()); + assertTrue(Arrays.equals(new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0 }, feature.getProcessedFeatures().get())); + + for (; i < 2 * shingleSize; i++) { + endTime += millisecondsPerMinute; + SinglePointFeatures singlePointFeatures = featureManager + .getShingledFeatureForHistoricalAnalysis(detector, shingle, Optional.of(new double[] { i }), endTime); + assertTrue(singlePointFeatures.getProcessedFeatures().isPresent()); + assertTrue( + Arrays + .equals( + new double[] { i - 7, i - 6, i - 5, i - 4, i - 3, i - 2, i - 1, i }, + singlePointFeatures.getProcessedFeatures().get() + ) + ); + } + } + + @Test + public void testGetShingledFeatureForHistoricalAnalysisWithTooManyMissingData() { + long millisecondsPerMinute = 60000; + int shingleSize = 8; + when(detector.getShingleSize()).thenReturn(shingleSize); + + Deque>> shingle = new ArrayDeque<>(); + + long endTime = Instant.now().toEpochMilli(); + int i = 0; + for (; i < shingleSize; i++) { + featureManager.getShingledFeatureForHistoricalAnalysis(detector, shingle, Optional.of(new double[] { i }), endTime); + endTime += millisecondsPerMinute; + } + + for (int j = 0; j < MAX_IMPUTATION_NEIGHBOR_DISTANCE + 1; j++) { + SinglePointFeatures feature = featureManager + .getShingledFeatureForHistoricalAnalysis(detector, shingle, Optional.empty(), endTime); + endTime += millisecondsPerMinute; + assertFalse(feature.getProcessedFeatures().isPresent()); + } + SinglePointFeatures feature = featureManager + .getShingledFeatureForHistoricalAnalysis(detector, shingle, Optional.of(new double[] { i }), endTime); + assertFalse(feature.getProcessedFeatures().isPresent()); + } + + @Test + public void testGetShingledFeatureForHistoricalAnalysisWithOneMissingData() { + long millisecondsPerMinute = 60000; + int shingleSize = 8; + when(detector.getShingleSize()).thenReturn(shingleSize); + + Deque>> shingle = new ArrayDeque<>(); + + long endTime = Instant.now().toEpochMilli(); + int i = 0; + for (; i < shingleSize; i++) { + featureManager.getShingledFeatureForHistoricalAnalysis(detector, shingle, Optional.of(new double[] { i }), endTime); + endTime += millisecondsPerMinute; + } + + SinglePointFeatures feature1 = featureManager.getShingledFeatureForHistoricalAnalysis(detector, shingle, Optional.empty(), endTime); + assertFalse(feature1.getProcessedFeatures().isPresent()); + + endTime += millisecondsPerMinute; + SinglePointFeatures feature2 = featureManager + .getShingledFeatureForHistoricalAnalysis(detector, shingle, Optional.of(new double[] { i }), endTime); + assertTrue(feature2.getProcessedFeatures().isPresent()); + assertTrue(Arrays.equals(new double[] { 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 8.0 }, feature2.getProcessedFeatures().get())); + } +} diff --git a/src/test/java/org/opensearch/ad/feature/FeaturesTests.java-e b/src/test/java/org/opensearch/ad/feature/FeaturesTests.java-e new file mode 100644 index 000000000..447bdd6c4 --- /dev/null +++ b/src/test/java/org/opensearch/ad/feature/FeaturesTests.java-e @@ -0,0 +1,74 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require 
contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.feature; + +import static org.junit.Assert.assertEquals; + +import java.util.AbstractMap.SimpleEntry; +import java.util.Arrays; +import java.util.List; +import java.util.Map.Entry; + +import junitparams.JUnitParamsRunner; +import junitparams.Parameters; + +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(JUnitParamsRunner.class) +public class FeaturesTests { + + private List> ranges = Arrays.asList(new SimpleEntry<>(0L, 1L)); + private double[][] unprocessed = new double[][] { { 1, 2 } }; + private double[][] processed = new double[][] { { 3, 4 } }; + + private Features features = new Features(ranges, unprocessed, processed); + + @Test + public void getters_returnExcepted() { + assertEquals(ranges, features.getTimeRanges()); + assertEquals(unprocessed, features.getUnprocessedFeatures()); + assertEquals(processed, features.getProcessedFeatures()); + } + + private Object[] equalsData() { + return new Object[] { + new Object[] { features, features, true }, + new Object[] { features, new Features(ranges, unprocessed, processed), true }, + new Object[] { features, null, false }, + new Object[] { features, "testString", false }, + new Object[] { features, new Features(null, unprocessed, processed), false }, + new Object[] { features, new Features(ranges, null, processed), false }, + new Object[] { features, new Features(ranges, unprocessed, null), false }, }; + } + + @Test + @Parameters(method = "equalsData") + public void equals_returnExpected(Features result, Object other, boolean expected) { + assertEquals(expected, result.equals(other)); + } + + private Object[] hashCodeData() { + Features features = new Features(ranges, unprocessed, processed); + return new Object[] { + new Object[] { features, new Features(ranges, unprocessed, processed), true }, + new Object[] { features, new Features(null, unprocessed, processed), false }, + new Object[] { features, new Features(ranges, null, processed), false }, + new Object[] { features, new Features(ranges, unprocessed, null), false }, }; + } + + @Test + @Parameters(method = "hashCodeData") + public void hashCode_returnExpected(Features result, Features other, boolean expected) { + assertEquals(expected, result.hashCode() == other.hashCode()); + } +} diff --git a/src/test/java/org/opensearch/ad/feature/NoPowermockSearchFeatureDaoTests.java-e b/src/test/java/org/opensearch/ad/feature/NoPowermockSearchFeatureDaoTests.java-e new file mode 100644 index 000000000..1d0da6d19 --- /dev/null +++ b/src/test/java/org/opensearch/ad/feature/NoPowermockSearchFeatureDaoTests.java-e @@ -0,0 +1,656 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.feature; + +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.nio.charset.StandardCharsets; +import java.time.Clock; +import java.time.ZoneOffset; +import java.time.temporal.ChronoUnit; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.util.BytesRef; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponse.Clusters; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.lease.Releasables; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.time.DateFormatter; +import org.opensearch.common.util.MockBigArrays; +import org.opensearch.common.util.MockPageCacheRecycler; +import org.opensearch.index.mapper.DateFieldMapper; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.indices.breaker.NoneCircuitBreakerService; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.aggregations.InternalOrder; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; +import org.opensearch.search.aggregations.bucket.filter.InternalFilter; +import org.opensearch.search.aggregations.bucket.filter.InternalFilters; +import org.opensearch.search.aggregations.bucket.filter.InternalFilters.InternalBucket; +import org.opensearch.search.aggregations.bucket.range.InternalDateRange; +import org.opensearch.search.aggregations.bucket.terms.StringTerms; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import 
org.opensearch.search.aggregations.metrics.AbstractHyperLogLog; +import org.opensearch.search.aggregations.metrics.AbstractHyperLogLogPlusPlus; +import org.opensearch.search.aggregations.metrics.HyperLogLogPlusPlus; +import org.opensearch.search.aggregations.metrics.InternalCardinality; +import org.opensearch.search.aggregations.metrics.InternalMax; +import org.opensearch.search.aggregations.metrics.SumAggregationBuilder; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.dataprocessor.Imputer; +import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; + +import com.carrotsearch.hppc.BitMixer; +import com.google.common.collect.ImmutableList; + +/** + * SearchFeatureDaoTests uses Powermock and has strange log4j related errors. + * Create a new class for new tests related to SearchFeatureDao. + * + */ +public class NoPowermockSearchFeatureDaoTests extends AbstractTimeSeriesTest { + private final Logger LOG = LogManager.getLogger(NoPowermockSearchFeatureDaoTests.class); + + private AnomalyDetector detector; + private Client client; + private SearchFeatureDao searchFeatureDao; + private Imputer imputer; + private SecurityClientUtil clientUtil; + private Settings settings; + private ClusterService clusterService; + private Clock clock; + private String serviceField, hostField; + private String detectorId; + private Map attrs1, attrs2; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(NoPowermockSearchFeatureDaoTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @Override + public void setUp() throws Exception { + super.setUp(); + serviceField = "service"; + hostField = "host"; + + detector = mock(AnomalyDetector.class); + when(detector.isHighCardinality()).thenReturn(true); + when(detector.getCategoryFields()).thenReturn(Arrays.asList(new String[] { serviceField, hostField })); + detectorId = "123"; + when(detector.getId()).thenReturn(detectorId); + when(detector.getTimeField()).thenReturn("testTimeField"); + when(detector.getIndices()).thenReturn(Arrays.asList("testIndices")); + IntervalTimeConfiguration detectionInterval = new IntervalTimeConfiguration(1, ChronoUnit.MINUTES); + when(detector.getInterval()).thenReturn(detectionInterval); + when(detector.getFilterQuery()).thenReturn(QueryBuilders.matchAllQuery()); + + client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + + imputer = new LinearUniformImputer(false); + + settings = Settings.EMPTY; + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + new HashSet<>(Arrays.asList(AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW, AnomalyDetectorSettings.PAGE_SIZE)) + ) + ); + clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + clock = mock(Clock.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + clientUtil = new 
SecurityClientUtil(nodeStateManager, settings); + + searchFeatureDao = new SearchFeatureDao( + client, + xContentRegistry(), // Important. Without this, ParseUtils cannot parse anything + imputer, + clientUtil, + settings, + clusterService, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + clock, + 1, + 1, + 60_000L + ); + + String app0 = "app_0"; + String server1 = "server_1"; + + attrs1 = new HashMap<>(); + attrs1.put(serviceField, app0); + attrs1.put(hostField, server1); + + String server2 = "server_2"; + attrs1 = new HashMap<>(); + attrs1.put(serviceField, app0); + attrs1.put(hostField, server2); + } + + private SearchResponse createPageResponse(Map attrs) { + CompositeAggregation pageOneComposite = mock(CompositeAggregation.class); + when(pageOneComposite.getName()).thenReturn(SearchFeatureDao.AGG_NAME_TOP); + when(pageOneComposite.afterKey()).thenReturn(attrs); + + List pageOneBuckets = new ArrayList<>(); + CompositeAggregation.Bucket bucket = mock(CompositeAggregation.Bucket.class); + when(bucket.getKey()).thenReturn(attrs); + when(bucket.getDocCount()).thenReturn(1552L); + pageOneBuckets.add(bucket); + + when(pageOneComposite.getBuckets()) + .thenAnswer((Answer>) invocation -> { return pageOneBuckets; }); + + Aggregations pageOneAggs = new Aggregations(Collections.singletonList(pageOneComposite)); + + SearchResponseSections pageOneSections = new SearchResponseSections(SearchHits.empty(), pageOneAggs, null, false, null, null, 1); + + return new SearchResponse(pageOneSections, null, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY, Clusters.EMPTY); + } + + @SuppressWarnings("unchecked") + @Test + public void testGetHighestCountEntitiesUsingTermsAgg() { + SearchHits hits = new SearchHits(new SearchHit[] {}, null, Float.NaN); + + String entity1Name = "value1"; + long entity1Count = 3; + StringTerms.Bucket entity1Bucket = new StringTerms.Bucket( + new BytesRef(entity1Name.getBytes(StandardCharsets.UTF_8), 0, entity1Name.getBytes(StandardCharsets.UTF_8).length), + entity1Count, + null, + false, + 0L, + DocValueFormat.RAW + ); + String entity2Name = "value2"; + long entity2Count = 1; + StringTerms.Bucket entity2Bucket = new StringTerms.Bucket( + new BytesRef(entity2Name.getBytes(StandardCharsets.UTF_8), 0, entity2Name.getBytes(StandardCharsets.UTF_8).length), + entity2Count, + null, + false, + 0, + DocValueFormat.RAW + ); + List stringBuckets = ImmutableList.of(entity1Bucket, entity2Bucket); + StringTerms termsAgg = new StringTerms( + // "term_agg", + SearchFeatureDao.AGG_NAME_TOP, + InternalOrder.key(false), + BucketOrder.count(false), + 1, + 0, + Collections.emptyMap(), + DocValueFormat.RAW, + 1, + false, + 0, + stringBuckets, + 0 + ); + + InternalAggregations internalAggregations = InternalAggregations.from(Collections.singletonList(termsAgg)); + + SearchResponseSections searchSections = new SearchResponseSections(hits, internalAggregations, null, false, false, null, 1); + + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 30, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + + doAnswer(invocation -> { + SearchRequest request = invocation.getArgument(0); + assertEquals(1, request.indices().length); + assertTrue(detector.getIndices().contains(request.indices()[0])); + AggregatorFactories.Builder aggs = request.source().aggregations(); + assertEquals(1, aggs.count()); + Collection factory = aggs.getAggregatorFactories(); + assertTrue(!factory.isEmpty()); + assertThat(factory.iterator().next(), 
instanceOf(TermsAggregationBuilder.class)); + + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); + + String categoryField = "fieldName"; + when(detector.getCategoryFields()).thenReturn(Collections.singletonList(categoryField)); + ActionListener> listener = mock(ActionListener.class); + searchFeatureDao.getHighestCountEntities(detector, 10L, 20L, listener); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(listener).onResponse(captor.capture()); + List result = captor.getValue(); + assertEquals(2, result.size()); + assertEquals(Entity.createSingleAttributeEntity(categoryField, entity1Name), result.get(0)); + assertEquals(Entity.createSingleAttributeEntity(categoryField, entity2Name), result.get(1)); + } + + @SuppressWarnings("unchecked") + public void testGetHighestCountEntitiesUsingPagination() { + SearchResponse response1 = createPageResponse(attrs1); + + CountDownLatch inProgress = new CountDownLatch(1); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + + inProgress.countDown(); + listener.onResponse(response1); + + return null; + }).when(client).search(any(), any()); + + ActionListener> listener = mock(ActionListener.class); + + searchFeatureDao.getHighestCountEntities(detector, 10L, 20L, listener); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(listener).onResponse(captor.capture()); + List result = captor.getValue(); + assertEquals(1, result.size()); + assertEquals(Entity.createEntityByReordering(attrs1), result.get(0)); + } + + @SuppressWarnings("unchecked") + public void testGetHighestCountEntitiesExhaustedPages() throws InterruptedException { + SearchResponse response1 = createPageResponse(attrs1); + + CompositeAggregation emptyComposite = mock(CompositeAggregation.class); + when(emptyComposite.getName()).thenReturn(SearchFeatureDao.AGG_NAME_TOP); + when(emptyComposite.afterKey()).thenReturn(null); + // empty bucket + when(emptyComposite.getBuckets()) + .thenAnswer((Answer>) invocation -> { return new ArrayList(); }); + Aggregations emptyAggs = new Aggregations(Collections.singletonList(emptyComposite)); + SearchResponseSections emptySections = new SearchResponseSections(SearchHits.empty(), emptyAggs, null, false, null, null, 1); + SearchResponse emptyResponse = new SearchResponse(emptySections, null, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY, Clusters.EMPTY); + + CountDownLatch inProgress = new CountDownLatch(2); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + inProgress.countDown(); + if (inProgress.getCount() == 1) { + listener.onResponse(response1); + } else { + listener.onResponse(emptyResponse); + } + + return null; + }).when(client).search(any(), any()); + + ActionListener> listener = mock(ActionListener.class); + + searchFeatureDao = new SearchFeatureDao( + client, + xContentRegistry(), + imputer, + clientUtil, + settings, + clusterService, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + clock, + 2, + 1, + 60_000L + ); + + searchFeatureDao.getHighestCountEntities(detector, 10L, 20L, listener); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(listener).onResponse(captor.capture()); + List result = captor.getValue(); + assertEquals(1, result.size()); + assertEquals(Entity.createEntityByReordering(attrs1), result.get(0)); + // both counts are used in client.search + 
assertTrue(inProgress.await(10000L, TimeUnit.MILLISECONDS)); + } + + @SuppressWarnings("unchecked") + public void testGetHighestCountEntitiesNotEnoughTime() throws InterruptedException { + SearchResponse response1 = createPageResponse(attrs1); + SearchResponse response2 = createPageResponse(attrs2); + + CountDownLatch inProgress = new CountDownLatch(2); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + inProgress.countDown(); + if (inProgress.getCount() == 1) { + listener.onResponse(response1); + } else { + listener.onResponse(response2); + } + + return null; + }).when(client).search(any(), any()); + + ActionListener> listener = mock(ActionListener.class); + + long timeoutMillis = 60_000L; + searchFeatureDao = new SearchFeatureDao( + client, + xContentRegistry(), + imputer, + clientUtil, + settings, + clusterService, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + clock, + 2, + 1, + timeoutMillis + ); + + CountDownLatch clockInvoked = new CountDownLatch(2); + + when(clock.millis()).thenAnswer(new Answer() { + @Override + public Long answer(InvocationOnMock invocation) throws Throwable { + clockInvoked.countDown(); + if (clockInvoked.getCount() == 1) { + return 1L; + } else { + return 2L + timeoutMillis; + } + } + }); + + searchFeatureDao.getHighestCountEntities(detector, 10L, 20L, listener); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); + verify(listener).onResponse(captor.capture()); + List result = captor.getValue(); + assertEquals(1, result.size()); + assertEquals(Entity.createEntityByReordering(attrs1), result.get(0)); + // exited early due to timeout + assertEquals(1, inProgress.getCount()); + // first called to create expired time; second called to check if time has expired + assertTrue(clockInvoked.await(10000L, TimeUnit.MILLISECONDS)); + } + + @SuppressWarnings("unchecked") + public void getColdStartSamplesForPeriodsTemplate(DocValueFormat format) throws IOException, InterruptedException { + detector = TestHelpers.AnomalyDetectorBuilder + .newInstance() + .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) + .setFeatureAttributes( + Collections.singletonList(new Feature("deny_sum", "deny sum", true, new SumAggregationBuilder("deny_sum").field("deny"))) + ) + .build(); + + InternalDateRange.Factory factory = new InternalDateRange.Factory(); + InternalDateRange.Bucket bucket1 = factory + .createBucket( + "1634786770964-1634786830964", + 1634786770964L, + 1634786830964L, + 1, + InternalAggregations.from(Arrays.asList(new InternalMax("deny_sum", 840.0, DocValueFormat.RAW, Collections.emptyMap()))), + false, + format + ); + InternalDateRange.Bucket bucket2 = factory + .createBucket( + "1634790370964-1634790430964", + 1634790370964L, + 1634790430964L, + 0, + InternalAggregations.from(Arrays.asList(new InternalMax("deny_sum", 0, DocValueFormat.RAW, Collections.emptyMap()))), + false, + format + ); + InternalDateRange.Bucket bucket3 = factory + .createBucket( + "1634793970964-1634794030964", + 1634793970964L, + 1634794030964L, + 1, + InternalAggregations.from(Arrays.asList(new InternalMax("deny_sum", 3489.0, DocValueFormat.RAW, Collections.emptyMap()))), + false, + format + ); + InternalDateRange range = factory + .create("date_range", Arrays.asList(bucket2, bucket3, bucket1), DocValueFormat.RAW, false, Collections.emptyMap()); + + InternalAggregations aggs = InternalAggregations.from(Arrays.asList(range)); + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(2189, TotalHits.Relation.EQUAL_TO), Float.NaN); 
+ InternalSearchResponse internalSearchResponse = new InternalSearchResponse(searchHits, aggs, null, null, false, null, 1); + SearchResponse response = new SearchResponse( + internalSearchResponse, + null, + 1, + 1, + 0, + 4, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).search(any(), any(ActionListener.class)); + + List> sampleRanges = new ArrayList<>(); + sampleRanges.add(new SimpleImmutableEntry(1634793970964L, 1634794030964L)); + sampleRanges.add(new SimpleImmutableEntry(1634790370964L, 1634790430964L)); + sampleRanges.add(new SimpleImmutableEntry(1634786770964L, 1634786830964L)); + + CountDownLatch inProgressLatch = new CountDownLatch(1); + + // test that the results are in ascending order of time and zero doc results are not ignored + searchFeatureDao + .getColdStartSamplesForPeriods( + detector, + sampleRanges, + Entity.createSingleAttributeEntity("field", "abc"), + true, + ActionListener.wrap(samples -> { + assertEquals(3, samples.size()); + for (int i = 0; i < samples.size(); i++) { + Optional sample = samples.get(i); + double[] array = sample.get(); + assertEquals(1, array.length); + if (i == 0) { + assertEquals(840, array[0], 1e-10); + } else if (i == 1) { + assertEquals(0, array[0], 1e-10); + } else { + assertEquals(3489.0, array[0], 1e-10); + } + } + inProgressLatch.countDown(); + }, exception -> { + LOG.error("stack trace", exception); + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }) + ); + + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + + CountDownLatch inProgressLatch2 = new CountDownLatch(1); + + // test that the results are in ascending order of time and zero doc results are ignored + searchFeatureDao + .getColdStartSamplesForPeriods( + detector, + sampleRanges, + Entity.createSingleAttributeEntity("field", "abc"), + false, + ActionListener.wrap(samples -> { + assertEquals(2, samples.size()); + for (int i = 0; i < samples.size(); i++) { + Optional sample = samples.get(i); + double[] array = sample.get(); + assertEquals(1, array.length); + if (i == 0) { + assertEquals(840, array[0], 1e-10); + } else { + assertEquals(3489.0, array[0], 1e-10); + } + } + inProgressLatch2.countDown(); + }, exception -> { + LOG.error("stack trace", exception); + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }) + ); + + assertTrue(inProgressLatch2.await(100, TimeUnit.SECONDS)); + } + + public void testGetColdStartSamplesForPeriodsMillisFormat() throws IOException, InterruptedException { + DocValueFormat format = new DocValueFormat.DateTime( + DateFormatter.forPattern("epoch_millis"), + ZoneOffset.UTC, + DateFieldMapper.Resolution.MILLISECONDS + ); + getColdStartSamplesForPeriodsTemplate(format); + } + + public void testGetColdStartSamplesForPeriodsDefaultFormat() throws IOException, InterruptedException { + DocValueFormat format = new DocValueFormat.DateTime( + DateFieldMapper.DEFAULT_DATE_TIME_FORMATTER, + ZoneOffset.UTC, + DateFieldMapper.Resolution.MILLISECONDS + ); + getColdStartSamplesForPeriodsTemplate(format); + } + + public void testGetColdStartSamplesForPeriodsRawFormat() throws IOException, InterruptedException { + getColdStartSamplesForPeriodsTemplate(DocValueFormat.RAW); + } + + @SuppressWarnings("rawtypes") + public void testParseBuckets() throws InstantiationException, + IllegalAccessException, + IllegalArgumentException, + 
InvocationTargetException, + NoSuchMethodException, + SecurityException { + // cannot mock final class HyperLogLogPlusPlus + HyperLogLogPlusPlus hllpp = new HyperLogLogPlusPlus( + randomIntBetween(AbstractHyperLogLog.MIN_PRECISION, AbstractHyperLogLog.MAX_PRECISION), + new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()), + 1 + ); + long hash1 = BitMixer.mix64(randomIntBetween(1, 100)); + long hash2 = BitMixer.mix64(randomIntBetween(1, 100)); + hllpp.collect(0, hash1); + hllpp.collect(0, hash2); + + Constructor ctor = null; + ctor = InternalCardinality.class.getDeclaredConstructor(String.class, AbstractHyperLogLogPlusPlus.class, Map.class); + ctor.setAccessible(true); + InternalCardinality cardinality = (InternalCardinality) ctor.newInstance("impactUniqueAccounts", hllpp, new HashMap<>()); + + // have to use reflection as all of InternalFilter's constructor are not public + ctor = InternalFilter.class.getDeclaredConstructor(String.class, long.class, InternalAggregations.class, Map.class); + + ctor.setAccessible(true); + String featureId = "deny_max"; + InternalFilter internalFilter = (InternalFilter) ctor + .newInstance(featureId, 100, InternalAggregations.from(Arrays.asList(cardinality)), new HashMap<>()); + InternalBucket bucket = new InternalFilters.InternalBucket( + "test", + randomIntBetween(0, 1000), + InternalAggregations.from(Arrays.asList(internalFilter)), + true + ); + + Optional parsedResult = searchFeatureDao.parseBucket(bucket, Arrays.asList(featureId)); + + assertTrue(parsedResult.isPresent()); + double[] parsedCardinality = parsedResult.get(); + assertEquals(1, parsedCardinality.length); + double buckets = hash1 == hash2 ? 1 : 2; + assertEquals(buckets, parsedCardinality[0], 0.001); + + // release MockBigArrays; otherwise, test will fail + Releasables.close(hllpp); + } +} diff --git a/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoParamTests.java b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoParamTests.java index 5776086b4..e00225ef0 100644 --- a/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoParamTests.java +++ b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoParamTests.java @@ -52,7 +52,6 @@ import org.opensearch.action.search.MultiSearchResponse.Item; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; @@ -76,6 +75,7 @@ import org.opensearch.search.aggregations.metrics.Percentile; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.dataprocessor.Imputer; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; @@ -160,7 +160,7 @@ public void setup() throws Exception { imputer = new LinearUniformImputer(false); ExecutorService executorService = mock(ExecutorService.class); - when(threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + when(threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); doAnswer(invocation -> { Runnable runnable = invocation.getArgument(0); runnable.run(); diff --git 
a/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoParamTests.java-e b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoParamTests.java-e new file mode 100644 index 000000000..e00225ef0 --- /dev/null +++ b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoParamTests.java-e @@ -0,0 +1,439 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.feature; + +import static java.util.Arrays.asList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.time.Clock; +import java.time.temporal.ChronoUnit; +import java.util.AbstractMap.SimpleEntry; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.concurrent.ExecutorService; + +import junitparams.JUnitParamsRunner; +import junitparams.Parameters; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionFuture; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.MultiSearchRequest; +import org.opensearch.action.search.MultiSearchResponse; +import org.opensearch.action.search.MultiSearchResponse.Item; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.script.ScriptService; +import org.opensearch.script.TemplateScript; +import org.opensearch.script.TemplateScript.Factory; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.metrics.InternalTDigestPercentiles; +import org.opensearch.search.aggregations.metrics.Max; +import org.opensearch.search.aggregations.metrics.Percentile; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.constant.CommonName; +import 
org.opensearch.timeseries.dataprocessor.Imputer; +import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.util.ParseUtils; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import org.powermock.modules.junit4.PowerMockRunnerDelegate; + +import com.google.gson.Gson; + +/** + * Due to https://tinyurl.com/2y265s2w, tests with and without @Parameters annotation + * are incompatible with each other. This class tests SearchFeatureDao using @Parameters, + * while SearchFeatureDaoTests do not use @Parameters. + * + */ +@PowerMockIgnore("javax.management.*") +@RunWith(PowerMockRunner.class) +@PowerMockRunnerDelegate(JUnitParamsRunner.class) +@PrepareForTest({ ParseUtils.class, Gson.class }) +public class SearchFeatureDaoParamTests { + + private SearchFeatureDao searchFeatureDao; + + @Mock + private Client client; + @Mock + private ScriptService scriptService; + @Mock + private NamedXContentRegistry xContent; + private SecurityClientUtil clientUtil; + + @Mock + private Factory factory; + @Mock + private TemplateScript templateScript; + @Mock + private ActionFuture searchResponseFuture; + @Mock + private ActionFuture multiSearchResponseFuture; + @Mock + private SearchResponse searchResponse; + @Mock + private MultiSearchResponse multiSearchResponse; + @Mock + private Item multiSearchResponseItem; + @Mock + private Aggregations aggs; + @Mock + private Max max; + @Mock + private NodeStateManager stateManager; + + @Mock + private AnomalyDetector detector; + + @Mock + private ThreadPool threadPool; + + @Mock + private ClusterService clusterService; + + @Mock + private Clock clock; + + private SearchRequest searchRequest; + private SearchSourceBuilder searchSourceBuilder; + private MultiSearchRequest multiSearchRequest; + private IntervalTimeConfiguration detectionInterval; + private String detectorId; + private Imputer imputer; + private Settings settings; + + @Before + public void setup() throws Exception { + MockitoAnnotations.initMocks(this); + PowerMockito.mockStatic(ParseUtils.class); + + imputer = new LinearUniformImputer(false); + + ExecutorService executorService = mock(ExecutorService.class); + when(threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + + settings = Settings.EMPTY; + + when(client.threadPool()).thenReturn(threadPool); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); + searchFeatureDao = spy( + new SearchFeatureDao(client, xContent, imputer, clientUtil, settings, null, AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE) + ); + + detectionInterval = new IntervalTimeConfiguration(1, ChronoUnit.MINUTES); + detectorId = "123"; + + when(detector.getId()).thenReturn(detectorId); + when(detector.getTimeField()).thenReturn("testTimeField"); + 
when(detector.getIndices()).thenReturn(Arrays.asList("testIndices")); + when(detector.getInterval()).thenReturn(detectionInterval); + when(detector.getFilterQuery()).thenReturn(QueryBuilders.matchAllQuery()); + when(detector.getCategoryFields()).thenReturn(Collections.singletonList("a")); + + searchSourceBuilder = SearchSourceBuilder + .fromXContent(XContentType.JSON.xContent().createParser(xContent, LoggingDeprecationHandler.INSTANCE, "{}")); + searchRequest = new SearchRequest(detector.getIndices().toArray(new String[0])); + + when(max.getName()).thenReturn(CommonName.AGG_NAME_MAX_TIME); + List list = new ArrayList<>(); + list.add(max); + Aggregations aggregations = new Aggregations(list); + SearchHits hits = new SearchHits(new SearchHit[0], new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f); + when(searchResponse.getHits()).thenReturn(hits); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(eq(searchRequest), any()); + when(searchResponse.getAggregations()).thenReturn(aggregations); + + multiSearchRequest = new MultiSearchRequest(); + SearchRequest request = new SearchRequest(detector.getIndices().toArray(new String[0])); + multiSearchRequest.add(request); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + listener.onResponse(multiSearchResponse); + return null; + }).when(client).multiSearch(eq(multiSearchRequest), any()); + when(multiSearchResponse.getResponses()).thenReturn(new Item[] { multiSearchResponseItem }); + when(multiSearchResponseItem.getResponse()).thenReturn(searchResponse); + } + + @Test + @Parameters(method = "getFeaturesForPeriodData") + @SuppressWarnings("unchecked") + public void getFeaturesForPeriod_returnExpectedToListener(List aggs, List featureIds, double[] expected) + throws Exception { + + long start = 100L; + long end = 200L; + when(ParseUtils.generateInternalFeatureQuery(eq(detector), eq(start), eq(end), eq(xContent))).thenReturn(searchSourceBuilder); + when(searchResponse.getAggregations()).thenReturn(new Aggregations(aggs)); + when(detector.getEnabledFeatureIds()).thenReturn(featureIds); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(eq(searchRequest), any(ActionListener.class)); + + ActionListener> listener = mock(ActionListener.class); + searchFeatureDao.getFeaturesForPeriod(detector, start, end, listener); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(Optional.class); + verify(listener).onResponse(captor.capture()); + Optional result = captor.getValue(); + assertTrue(Arrays.equals(expected, result.orElse(null))); + } + + @Test + @Parameters(method = "getFeaturesForSampledPeriodsData") + @SuppressWarnings("unchecked") + public void getFeaturesForSampledPeriods_returnExpectedToListener( + Long[][] queryRanges, + double[][] queryResults, + long endTime, + int maxStride, + int maxSamples, + Optional> expected + ) { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onResponse(Optional.empty()); + return null; + }).when(searchFeatureDao).getFeaturesForPeriod(any(), anyLong(), anyLong(), any(ActionListener.class)); + for (int i = 0; i < queryRanges.length; i++) { + double[] queryResult = queryResults[i]; + doAnswer(invocation -> { + ActionListener> listener = 
invocation.getArgument(3); + listener.onResponse(Optional.of(queryResult)); + return null; + }) + .when(searchFeatureDao) + .getFeaturesForPeriod(eq(detector), eq(queryRanges[i][0]), eq(queryRanges[i][1]), any(ActionListener.class)); + } + + ActionListener>> listener = mock(ActionListener.class); + searchFeatureDao.getFeaturesForSampledPeriods(detector, maxSamples, maxStride, endTime, listener); + + ArgumentCaptor>> captor = ArgumentCaptor.forClass(Optional.class); + verify(listener).onResponse(captor.capture()); + Optional> result = captor.getValue(); + assertEquals(expected.isPresent(), result.isPresent()); + if (expected.isPresent()) { + assertTrue(Arrays.deepEquals(expected.get().getKey(), result.get().getKey())); + assertEquals(expected.get().getValue(), result.get().getValue()); + } + } + + @SuppressWarnings("unchecked") + private Object[] getFeaturesForPeriodData() { + String maxName = "max"; + double maxValue = 2; + Max max = mock(Max.class); + when(max.value()).thenReturn(maxValue); + when(max.getName()).thenReturn(maxName); + + String percentileName = "percentile"; + double percentileValue = 1; + InternalTDigestPercentiles percentiles = mock(InternalTDigestPercentiles.class); + Iterator percentilesIterator = mock(Iterator.class); + Percentile percentile = mock(Percentile.class); + when(percentiles.iterator()).thenReturn(percentilesIterator); + when(percentilesIterator.hasNext()).thenReturn(true); + when(percentilesIterator.next()).thenReturn(percentile); + when(percentile.getValue()).thenReturn(percentileValue); + when(percentiles.getName()).thenReturn(percentileName); + + String missingName = "missing"; + Max missing = mock(Max.class); + when(missing.value()).thenReturn(Double.NaN); + when(missing.getName()).thenReturn(missingName); + + String infinityName = "infinity"; + Max infinity = mock(Max.class); + when(infinity.value()).thenReturn(Double.POSITIVE_INFINITY); + when(infinity.getName()).thenReturn(infinityName); + + String emptyName = "empty"; + InternalTDigestPercentiles empty = mock(InternalTDigestPercentiles.class); + Iterator emptyIterator = mock(Iterator.class); + when(empty.iterator()).thenReturn(emptyIterator); + when(emptyIterator.hasNext()).thenReturn(false); + when(empty.getName()).thenReturn(emptyName); + + return new Object[] { + new Object[] { asList(max), asList(maxName), new double[] { maxValue }, }, + new Object[] { asList(percentiles), asList(percentileName), new double[] { percentileValue } }, + new Object[] { asList(missing), asList(missingName), null }, + new Object[] { asList(infinity), asList(infinityName), null }, + new Object[] { asList(max, percentiles), asList(maxName, percentileName), new double[] { maxValue, percentileValue } }, + new Object[] { asList(max, percentiles), asList(percentileName, maxName), new double[] { percentileValue, maxValue } }, + new Object[] { asList(max, percentiles, missing), asList(maxName, percentileName, missingName), null }, }; + } + + private Object[] getFeaturesForSampledPeriodsData() { + long endTime = 300_000; + int maxStride = 4; + return new Object[] { + + // No data + + new Object[] { new Long[0][0], new double[0][0], endTime, 1, 1, Optional.empty() }, + + // 1 data point + + new Object[] { + new Long[][] { { 240_000L, 300_000L } }, + new double[][] { { 1, 2 } }, + endTime, + 1, + 1, + Optional.of(new SimpleEntry<>(new double[][] { { 1, 2 } }, 1)) }, + + new Object[] { + new Long[][] { { 240_000L, 300_000L } }, + new double[][] { { 1, 2 } }, + endTime, + 1, + 3, + Optional.of(new SimpleEntry<>(new double[][] { { 
1, 2 } }, 1)) }, + + // 2 data points + + new Object[] { + new Long[][] { { 180_000L, 240_000L }, { 240_000L, 300_000L } }, + new double[][] { { 1, 2 }, { 2, 4 } }, + endTime, + 1, + 2, + Optional.of(new SimpleEntry<>(new double[][] { { 1, 2 }, { 2, 4 } }, 1)) }, + + new Object[] { + new Long[][] { { 180_000L, 240_000L }, { 240_000L, 300_000L } }, + new double[][] { { 1, 2 }, { 2, 4 } }, + endTime, + 1, + 1, + Optional.of(new SimpleEntry<>(new double[][] { { 2, 4 } }, 1)) }, + + new Object[] { + new Long[][] { { 180_000L, 240_000L }, { 240_000L, 300_000L } }, + new double[][] { { 1, 2 }, { 2, 4 } }, + endTime, + 4, + 2, + Optional.of(new SimpleEntry<>(new double[][] { { 1, 2 }, { 2, 4 } }, 1)) }, + + new Object[] { + new Long[][] { { 0L, 60_000L }, { 240_000L, 300_000L } }, + new double[][] { { 1, 2 }, { 2, 4 } }, + endTime, + 4, + 2, + Optional.of(new SimpleEntry<>(new double[][] { { 1, 2 }, { 2, 4 } }, 4)) }, + + // 5 data points + + new Object[] { + new Long[][] { + { 0L, 60_000L }, + { 60_000L, 120_000L }, + { 120_000L, 180_000L }, + { 180_000L, 240_000L }, + { 240_000L, 300_000L } }, + new double[][] { { 1, 2 }, { 3, 4 }, { 5, 6 }, { 7, 8 }, { 9, 10 } }, + endTime, + 4, + 10, + Optional.of(new SimpleEntry<>(new double[][] { { 1, 2 }, { 3, 4 }, { 5, 6 }, { 7, 8 }, { 9, 10 } }, 1)) }, + + new Object[] { + new Long[][] { { 0L, 60_000L }, { 60_000L, 120_000L }, { 180_000L, 240_000L }, { 240_000L, 300_000L } }, + new double[][] { { 1, 2 }, { 3, 4 }, { 7, 8 }, { 9, 10 } }, + endTime, + 4, + 10, + Optional.of(new SimpleEntry<>(new double[][] { { 1, 2 }, { 3, 4 }, { 5, 6 }, { 7, 8 }, { 9, 10 } }, 1)) }, + + new Object[] { + new Long[][] { { 0L, 60_000L }, { 120_000L, 180_000L }, { 240_000L, 300_000L } }, + new double[][] { { 1, 2 }, { 5, 6 }, { 9, 10 } }, + endTime, + 4, + 10, + Optional.of(new SimpleEntry<>(new double[][] { { 1, 2 }, { 3, 4 }, { 5, 6 }, { 7, 8 }, { 9, 10 } }, 1)) }, + + new Object[] { + new Long[][] { { 0L, 60_000L }, { 240_000L, 300_000L } }, + new double[][] { { 1, 2 }, { 9, 10 } }, + endTime, + 4, + 10, + Optional.of(new SimpleEntry<>(new double[][] { { 1, 2 }, { 3, 4 }, { 5, 6 }, { 7, 8 }, { 9, 10 } }, 1)) }, }; + } +} diff --git a/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java index de75212fa..cf18b2fdd 100644 --- a/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java +++ b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java @@ -57,7 +57,6 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; @@ -91,6 +90,7 @@ import org.opensearch.search.aggregations.metrics.Percentile; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.dataprocessor.Imputer; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; @@ -163,7 +163,7 @@ public void setup() throws Exception { imputer = new LinearUniformImputer(false); ExecutorService executorService = mock(ExecutorService.class); - 
when(threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + when(threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); doAnswer(invocation -> { Runnable runnable = invocation.getArgument(0); runnable.run(); diff --git a/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java-e b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java-e new file mode 100644 index 000000000..cf18b2fdd --- /dev/null +++ b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java-e @@ -0,0 +1,388 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.feature; + +import static java.util.Arrays.asList; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.AnyOf.anyOf; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.time.Clock; +import java.time.ZoneId; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.concurrent.ExecutorService; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionFuture; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.MultiSearchRequest; +import org.opensearch.action.search.MultiSearchResponse; +import org.opensearch.action.search.MultiSearchResponse.Item; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.time.DateFormatter; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.mapper.DateFieldMapper; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.script.ScriptService; +import org.opensearch.script.TemplateScript; +import org.opensearch.script.TemplateScript.Factory; +import org.opensearch.search.DocValueFormat; +import 
org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; +import org.opensearch.search.aggregations.metrics.InternalMin; +import org.opensearch.search.aggregations.metrics.InternalTDigestPercentiles; +import org.opensearch.search.aggregations.metrics.Max; +import org.opensearch.search.aggregations.metrics.MaxAggregationBuilder; +import org.opensearch.search.aggregations.metrics.MinAggregationBuilder; +import org.opensearch.search.aggregations.metrics.Percentile; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.dataprocessor.Imputer; +import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.util.ParseUtils; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +@PowerMockIgnore("javax.management.*") +@RunWith(PowerMockRunner.class) +@PrepareForTest({ ParseUtils.class }) +public class SearchFeatureDaoTests { + private SearchFeatureDao searchFeatureDao; + + @Mock + private Client client; + @Mock + private ScriptService scriptService; + @Mock + private NamedXContentRegistry xContent; + private SecurityClientUtil clientUtil; + + @Mock + private Factory factory; + @Mock + private TemplateScript templateScript; + @Mock + private ActionFuture searchResponseFuture; + @Mock + private ActionFuture multiSearchResponseFuture; + @Mock + private SearchResponse searchResponse; + @Mock + private MultiSearchResponse multiSearchResponse; + @Mock + private Item multiSearchResponseItem; + @Mock + private Aggregations aggs; + @Mock + private Max max; + @Mock + private NodeStateManager stateManager; + + @Mock + private AnomalyDetector detector; + + @Mock + private ThreadPool threadPool; + + @Mock + private Clock clock; + + private SearchRequest searchRequest; + private SearchSourceBuilder searchSourceBuilder; + private MultiSearchRequest multiSearchRequest; + private Map aggsMap; + private IntervalTimeConfiguration detectionInterval; + private String detectorId; + private Imputer imputer; + private Settings settings; + + @Before + public void setup() throws Exception { + MockitoAnnotations.initMocks(this); + PowerMockito.mockStatic(ParseUtils.class); + + imputer = new LinearUniformImputer(false); + + ExecutorService executorService = mock(ExecutorService.class); + when(threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + + settings = Settings.EMPTY; + + when(client.threadPool()).thenReturn(threadPool); + 
NodeStateManager nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); + searchFeatureDao = spy( + new SearchFeatureDao(client, xContent, imputer, clientUtil, settings, null, AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE) + ); + + detectionInterval = new IntervalTimeConfiguration(1, ChronoUnit.MINUTES); + detectorId = "123"; + + when(detector.getId()).thenReturn(detectorId); + when(detector.getTimeField()).thenReturn("testTimeField"); + when(detector.getIndices()).thenReturn(Arrays.asList("testIndices")); + when(detector.getInterval()).thenReturn(detectionInterval); + when(detector.getFilterQuery()).thenReturn(QueryBuilders.matchAllQuery()); + when(detector.getCategoryFields()).thenReturn(Collections.singletonList("a")); + + searchSourceBuilder = SearchSourceBuilder + .fromXContent(XContentType.JSON.xContent().createParser(xContent, LoggingDeprecationHandler.INSTANCE, "{}")); + searchRequest = new SearchRequest(detector.getIndices().toArray(new String[0])); + aggsMap = new HashMap<>(); + + when(max.getName()).thenReturn(CommonName.AGG_NAME_MAX_TIME); + List list = new ArrayList<>(); + list.add(max); + Aggregations aggregations = new Aggregations(list); + SearchHits hits = new SearchHits(new SearchHit[0], new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f); + when(searchResponse.getHits()).thenReturn(hits); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(eq(searchRequest), any()); + when(searchResponse.getAggregations()).thenReturn(aggregations); + + multiSearchRequest = new MultiSearchRequest(); + SearchRequest request = new SearchRequest(detector.getIndices().toArray(new String[0])); + multiSearchRequest.add(request); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + listener.onResponse(multiSearchResponse); + return null; + }).when(client).multiSearch(eq(multiSearchRequest), any()); + when(multiSearchResponse.getResponses()).thenReturn(new Item[] { multiSearchResponseItem }); + when(multiSearchResponseItem.getResponse()).thenReturn(searchResponse); + + // gson = PowerMockito.mock(Gson.class); + } + + @SuppressWarnings("unchecked") + private Object[] getFeaturesForPeriodThrowIllegalStateData() { + String aggName = "aggName"; + + InternalTDigestPercentiles empty = mock(InternalTDigestPercentiles.class); + Iterator emptyIterator = mock(Iterator.class); + when(empty.iterator()).thenReturn(emptyIterator); + when(emptyIterator.hasNext()).thenReturn(false); + when(empty.getName()).thenReturn(aggName); + + MultiBucketsAggregation multiBucket = mock(MultiBucketsAggregation.class); + when(multiBucket.getName()).thenReturn(aggName); + + return new Object[] { + new Object[] { asList(empty), asList(aggName), null }, + new Object[] { asList(multiBucket), asList(aggName), null }, }; + } + + @Test + @SuppressWarnings("unchecked") + public void getLatestDataTime_returnExpectedToListener() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .aggregation(AggregationBuilders.max(CommonName.AGG_NAME_MAX_TIME).field(detector.getTimeField())) + .size(0); + 
searchRequest.source(searchSourceBuilder); + long epochTime = 100L; + aggsMap.put(CommonName.AGG_NAME_MAX_TIME, max); + when(max.getValue()).thenReturn((double) epochTime); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(eq(searchRequest), any(ActionListener.class)); + + when(ParseUtils.getLatestDataTime(eq(searchResponse))).thenReturn(Optional.of(epochTime)); + ActionListener> listener = mock(ActionListener.class); + searchFeatureDao.getLatestDataTime(detector, listener); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(Optional.class); + verify(listener).onResponse(captor.capture()); + Optional result = captor.getValue(); + assertEquals(epochTime, result.get().longValue()); + } + + @Test + @SuppressWarnings("unchecked") + public void getFeaturesForSampledPeriods_throwToListener_whenSamplingFail() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onFailure(new RuntimeException()); + return null; + }).when(searchFeatureDao).getFeaturesForPeriod(any(), anyLong(), anyLong(), any(ActionListener.class)); + + ActionListener>> listener = mock(ActionListener.class); + searchFeatureDao.getFeaturesForSampledPeriods(detector, 1, 1, 0, listener); + + verify(listener).onFailure(any(Exception.class)); + } + + @Test + @SuppressWarnings("unchecked") + public void getFeaturesForPeriod_throwToListener_whenResponseParsingFails() throws Exception { + + long start = 100L; + long end = 200L; + when(ParseUtils.generateInternalFeatureQuery(eq(detector), eq(start), eq(end), eq(xContent))).thenReturn(searchSourceBuilder); + when(detector.getEnabledFeatureIds()).thenReturn(null); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(eq(searchRequest), any(ActionListener.class)); + + ActionListener> listener = mock(ActionListener.class); + searchFeatureDao.getFeaturesForPeriod(detector, start, end, listener); + + verify(listener).onFailure(any(Exception.class)); + } + + @Test + @SuppressWarnings("unchecked") + public void getFeaturesForPeriod_throwToListener_whenSearchFails() throws Exception { + + long start = 100L; + long end = 200L; + when(ParseUtils.generateInternalFeatureQuery(eq(detector), eq(start), eq(end), eq(xContent))).thenReturn(searchSourceBuilder); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException()); + return null; + }).when(client).search(eq(searchRequest), any(ActionListener.class)); + + ActionListener> listener = mock(ActionListener.class); + searchFeatureDao.getFeaturesForPeriod(detector, start, end, listener); + + verify(listener).onFailure(any(Exception.class)); + } + + @SuppressWarnings("unchecked") + @Test + public void testGetEntityMinDataTime() { + // simulate response {"took":11,"timed_out":false,"_shards":{"total":1, + // "successful":1,"skipped":0,"failed":0},"hits":{"max_score":null,"hits":[]}, + // "aggregations":{"min_timefield":{"value":1.602211285E12, + // "value_as_string":"2020-10-09T02:41:25.000Z"}, + // "max_timefield":{"value":1.602348325E12,"value_as_string":"2020-10-10T16:45:25.000Z"}}} + DocValueFormat dateFormat = new DocValueFormat.DateTime( + DateFormatter.forPattern("strict_date_optional_time||epoch_millis"), + ZoneId.of("UTC"), + DateFieldMapper.Resolution.MILLISECONDS + ); + double earliest = 1.602211285E12; + InternalMin 
minInternal = new InternalMin("min_timefield", earliest, dateFormat, new HashMap<>()); + InternalAggregations internalAggregations = InternalAggregations.from(Arrays.asList(minInternal)); + SearchHits hits = new SearchHits(new SearchHit[] {}, null, Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections(hits, internalAggregations, null, false, false, null, 1); + + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + + doAnswer(invocation -> { + SearchRequest request = invocation.getArgument(0); + assertEquals(1, request.indices().length); + assertTrue(detector.getIndices().contains(request.indices()[0])); + AggregatorFactories.Builder aggs = request.source().aggregations(); + assertEquals(1, aggs.count()); + Collection factory = aggs.getAggregatorFactories(); + assertTrue(!factory.isEmpty()); + Iterator iterator = factory.iterator(); + while (iterator.hasNext()) { + assertThat(iterator.next(), anyOf(instanceOf(MaxAggregationBuilder.class), instanceOf(MinAggregationBuilder.class))); + } + + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); + + ActionListener> listener = mock(ActionListener.class); + Entity entity = Entity.createSingleAttributeEntity("field", "app_1"); + searchFeatureDao.getEntityMinDataTime(detector, entity, listener); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(Optional.class); + verify(listener).onResponse(captor.capture()); + Optional result = captor.getValue(); + assertEquals((long) earliest, result.get().longValue()); + } +} diff --git a/src/test/java/org/opensearch/ad/indices/AnomalyDetectionIndicesTests.java b/src/test/java/org/opensearch/ad/indices/AnomalyDetectionIndicesTests.java index d67dd1b01..313800385 100644 --- a/src/test/java/org/opensearch/ad/indices/AnomalyDetectionIndicesTests.java +++ b/src/test/java/org/opensearch/ad/indices/AnomalyDetectionIndicesTests.java @@ -16,12 +16,12 @@ import java.util.Collections; import org.junit.Before; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.plugins.Plugin; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.indices.IndexManagementIntegTestCase; import org.opensearch.timeseries.settings.TimeSeriesSettings; @@ -33,11 +33,12 @@ public class AnomalyDetectionIndicesTests extends IndexManagementIntegTestCase> nodePlugins() { - return Collections.singletonList(AnomalyDetectorPlugin.class); + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } @Before diff --git a/src/test/java/org/opensearch/ad/indices/AnomalyDetectionIndicesTests.java-e b/src/test/java/org/opensearch/ad/indices/AnomalyDetectionIndicesTests.java-e new file mode 100644 index 000000000..313800385 --- /dev/null +++ b/src/test/java/org/opensearch/ad/indices/AnomalyDetectionIndicesTests.java-e @@ -0,0 +1,161 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.indices; + +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; + +import org.junit.Before; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.plugins.Plugin; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.indices.IndexManagementIntegTestCase; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +public class AnomalyDetectionIndicesTests extends IndexManagementIntegTestCase { + + private ADIndexManagement indices; + private Settings settings; + private DiscoveryNodeFilterer nodeFilter; + + // help register setting using TimeSeriesAnalyticsPlugin.getSettings. + // Otherwise, ADIndexManagement's constructor would fail due to + // unregistered settings like AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD. + @Override + protected Collection> nodePlugins() { + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); + } + + @Before + public void setup() throws IOException { + settings = Settings + .builder() + .put("plugins.anomaly_detection.ad_result_history_rollover_period", TimeValue.timeValueHours(12)) + .put("plugins.anomaly_detection.ad_result_history_retention_period", TimeValue.timeValueHours(24)) + .put("plugins.anomaly_detection.ad_result_history_max_docs", 10000L) + .put("plugins.anomaly_detection.request_timeout", TimeValue.timeValueSeconds(10)) + .build(); + + nodeFilter = new DiscoveryNodeFilterer(clusterService()); + + indices = new ADIndexManagement( + client(), + clusterService(), + client().threadPool(), + settings, + nodeFilter, + TimeSeriesSettings.MAX_UPDATE_RETRY_TIMES + ); + } + + public void testAnomalyDetectorIndexNotExists() { + boolean exists = indices.doesConfigIndexExist(); + assertFalse(exists); + } + + public void testAnomalyDetectorIndexExists() throws IOException { + indices.initConfigIndexIfAbsent(TestHelpers.createActionListener(response -> { + boolean acknowledged = response.isAcknowledged(); + assertTrue(acknowledged); + }, failure -> { throw new RuntimeException("should not recreate index"); })); + TestHelpers.waitForIndexCreationToComplete(client(), CommonName.CONFIG_INDEX); + } + + public void testAnomalyDetectorIndexExistsAndNotRecreate() throws IOException { + indices + .initConfigIndexIfAbsent( + TestHelpers + .createActionListener( + response -> response.isAcknowledged(), + failure -> { throw new RuntimeException("should not recreate index"); } + ) + ); + TestHelpers.waitForIndexCreationToComplete(client(), CommonName.CONFIG_INDEX); + if (client().admin().indices().prepareExists(CommonName.CONFIG_INDEX).get().isExists()) { + indices + .initConfigIndexIfAbsent( + TestHelpers + .createActionListener( + response -> { throw new RuntimeException("should not recreate index " + CommonName.CONFIG_INDEX); }, + failure -> { throw new RuntimeException("should not recreate index " + CommonName.CONFIG_INDEX); } + ) + ); + } + } + + public void testAnomalyResultIndexNotExists() { + boolean exists = indices.doesDefaultResultIndexExist(); + assertFalse(exists); + } + + public void testAnomalyResultIndexExists() throws IOException { + 
indices.initDefaultResultIndexIfAbsent(TestHelpers.createActionListener(response -> { + boolean acknowledged = response.isAcknowledged(); + assertTrue(acknowledged); + }, failure -> { throw new RuntimeException("should not recreate index"); })); + TestHelpers.waitForIndexCreationToComplete(client(), ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); + } + + public void testAnomalyResultIndexExistsAndNotRecreate() throws IOException { + indices + .initDefaultResultIndexIfAbsent( + TestHelpers + .createActionListener( + response -> logger.info("Acknowledged: " + response.isAcknowledged()), + failure -> { throw new RuntimeException("should not recreate index"); } + ) + ); + TestHelpers.waitForIndexCreationToComplete(client(), ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); + if (client().admin().indices().prepareExists(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS).get().isExists()) { + indices + .initDefaultResultIndexIfAbsent( + TestHelpers + .createActionListener( + response -> { + throw new RuntimeException("should not recreate index " + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); + }, + failure -> { + throw new RuntimeException("should not recreate index " + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, failure); + } + ) + ); + } + } + + public void testGetDetectionStateIndexMapping() throws IOException { + String detectorIndexMappings = ADIndexManagement.getConfigMappings(); + detectorIndexMappings = detectorIndexMappings + .substring(detectorIndexMappings.indexOf("\"properties\""), detectorIndexMappings.lastIndexOf("}")); + String detectionStateIndexMapping = ADIndexManagement.getStateMappings(); + assertTrue(detectionStateIndexMapping.contains(detectorIndexMappings)); + } + + public void testValidateCustomIndexForBackendJob() throws IOException, InterruptedException { + String resultMapping = ADIndexManagement.getResultMappings(); + + validateCustomIndexForBackendJob(indices, resultMapping); + } + + public void testValidateCustomIndexForBackendJobInvalidMapping() { + validateCustomIndexForBackendJobInvalidMapping(indices); + } + + public void testValidateCustomIndexForBackendJobNoIndex() { + validateCustomIndexForBackendJobNoIndex(indices); + } +} diff --git a/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java-e b/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java-e new file mode 100644 index 000000000..53bea9015 --- /dev/null +++ b/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java-e @@ -0,0 +1,343 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.indices; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; + +import org.opensearch.Version; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +public class CustomIndexTests extends AbstractTimeSeriesTest { + ADIndexManagement adIndices; + Client client; + ClusterService clusterService; + DiscoveryNodeFilterer nodeFilter; + ClusterState clusterState; + String customIndexName; + ClusterName clusterName; + + @Override + public void setUp() throws Exception { + super.setUp(); + + client = mock(Client.class); + + clusterService = mock(ClusterService.class); + + clusterName = new ClusterName("test"); + + customIndexName = "opensearch-ad-plugin-result-a"; + // clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build(); + when(clusterService.state()).thenReturn(clusterState); + + Settings settings = Settings.EMPTY; + ClusterSettings clusterSettings = new ClusterSettings( + settings, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD, + AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, + AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD, + AnomalyDetectorSettings.AD_MAX_PRIMARY_SHARDS + ) + ) + ) + ); + + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + nodeFilter = mock(DiscoveryNodeFilterer.class); + + adIndices = new ADIndexManagement( + client, + clusterService, + threadPool, + settings, + nodeFilter, + TimeSeriesSettings.MAX_UPDATE_RETRY_TIMES + ); + } + + private Map createMapping() { + Map mappings = new HashMap<>(); + + Map grade_mapping = new HashMap<>(); + grade_mapping.put("type", "double"); + mappings.put(AnomalyResult.ANOMALY_GRADE_FIELD, grade_mapping); + + Map score_mapping = new HashMap<>(); + score_mapping.put("type", "double"); + mappings.put(AnomalyResult.ANOMALY_SCORE_FIELD, score_mapping); + + Map approx_mapping = new HashMap<>(); + approx_mapping.put("type", "date"); + approx_mapping.put("format", "strict_date_time||epoch_millis"); + mappings.put(AnomalyResult.APPROX_ANOMALY_START_FIELD, approx_mapping); + + Map confidence_mapping = new HashMap<>(); + confidence_mapping.put("type", "double"); + mappings.put(CommonName.CONFIDENCE_FIELD, confidence_mapping); + + Map data_end_time = new HashMap<>(); + data_end_time.put("type", "date"); + data_end_time.put("format", "strict_date_time||epoch_millis"); + mappings.put(CommonName.DATA_END_TIME_FIELD, data_end_time); + + Map data_start_time = new HashMap<>(); + data_start_time.put("type", "date"); + 
data_start_time.put("format", "strict_date_time||epoch_millis"); + mappings.put(CommonName.DATA_START_TIME_FIELD, data_start_time); + + Map exec_start_mapping = new HashMap<>(); + exec_start_mapping.put("type", "date"); + exec_start_mapping.put("format", "strict_date_time||epoch_millis"); + mappings.put(CommonName.EXECUTION_START_TIME_FIELD, exec_start_mapping); + + Map exec_end_mapping = new HashMap<>(); + exec_end_mapping.put("type", "date"); + exec_end_mapping.put("format", "strict_date_time||epoch_millis"); + mappings.put(CommonName.EXECUTION_END_TIME_FIELD, exec_end_mapping); + + Map detector_id_mapping = new HashMap<>(); + detector_id_mapping.put("type", "keyword"); + mappings.put(AnomalyResult.DETECTOR_ID_FIELD, detector_id_mapping); + + Map entity_mapping = new HashMap<>(); + entity_mapping.put("type", "nested"); + Map entity_nested_mapping = new HashMap<>(); + entity_nested_mapping.put("name", Collections.singletonMap("type", "keyword")); + entity_nested_mapping.put("value", Collections.singletonMap("type", "keyword")); + entity_mapping.put(CommonName.PROPERTIES, entity_nested_mapping); + mappings.put(CommonName.ENTITY_FIELD, entity_mapping); + + Map error_mapping = new HashMap<>(); + error_mapping.put("type", "text"); + mappings.put(CommonName.ERROR_FIELD, error_mapping); + + Map expected_mapping = new HashMap<>(); + expected_mapping.put("type", "nested"); + Map expected_nested_mapping = new HashMap<>(); + expected_mapping.put(CommonName.PROPERTIES, expected_nested_mapping); + expected_nested_mapping.put("likelihood", Collections.singletonMap("type", "double")); + Map value_list_mapping = new HashMap<>(); + expected_nested_mapping.put("value_list", value_list_mapping); + value_list_mapping.put("type", "nested"); + Map value_list_nested_mapping = new HashMap<>(); + value_list_mapping.put(CommonName.PROPERTIES, value_list_nested_mapping); + value_list_nested_mapping.put("data", Collections.singletonMap("type", "double")); + value_list_nested_mapping.put("feature_id", Collections.singletonMap("type", "keyword")); + mappings.put(AnomalyResult.EXPECTED_VALUES_FIELD, expected_mapping); + + Map feature_mapping = new HashMap<>(); + feature_mapping.put("type", "nested"); + Map feature_nested_mapping = new HashMap<>(); + feature_mapping.put(CommonName.PROPERTIES, feature_nested_mapping); + feature_nested_mapping.put("data", Collections.singletonMap("type", "double")); + feature_nested_mapping.put("feature_id", Collections.singletonMap("type", "keyword")); + mappings.put(CommonName.FEATURE_DATA_FIELD, feature_mapping); + mappings.put(AnomalyResult.IS_ANOMALY_FIELD, Collections.singletonMap("type", "boolean")); + mappings.put(CommonName.MODEL_ID_FIELD, Collections.singletonMap("type", "keyword")); + + Map past_mapping = new HashMap<>(); + past_mapping.put("type", "nested"); + Map past_nested_mapping = new HashMap<>(); + past_mapping.put(CommonName.PROPERTIES, past_nested_mapping); + past_nested_mapping.put("data", Collections.singletonMap("type", "double")); + past_nested_mapping.put("feature_id", Collections.singletonMap("type", "keyword")); + mappings.put(AnomalyResult.PAST_VALUES_FIELD, past_mapping); + + Map attribution_mapping = new HashMap<>(); + attribution_mapping.put("type", "nested"); + Map attribution_nested_mapping = new HashMap<>(); + attribution_mapping.put(CommonName.PROPERTIES, attribution_nested_mapping); + attribution_nested_mapping.put("data", Collections.singletonMap("type", "double")); + attribution_nested_mapping.put("feature_id", Collections.singletonMap("type", 
"keyword")); + mappings.put(AnomalyResult.RELEVANT_ATTRIBUTION_FIELD, attribution_mapping); + + mappings.put(CommonName.SCHEMA_VERSION_FIELD, Collections.singletonMap("type", "integer")); + + mappings.put(CommonName.TASK_ID_FIELD, Collections.singletonMap("type", "keyword")); + + mappings.put(AnomalyResult.THRESHOLD_FIELD, Collections.singletonMap("type", "double")); + + Map user_mapping = new HashMap<>(); + user_mapping.put("type", "nested"); + Map user_nested_mapping = new HashMap<>(); + user_mapping.put(CommonName.PROPERTIES, user_nested_mapping); + Map backend_role_mapping = new HashMap<>(); + backend_role_mapping.put("type", "text"); + backend_role_mapping.put("fields", Collections.singletonMap("keyword", Collections.singletonMap("type", "keyword"))); + user_nested_mapping.put("backend_roles", backend_role_mapping); + Map custom_attribute_mapping = new HashMap<>(); + custom_attribute_mapping.put("type", "text"); + custom_attribute_mapping.put("fields", Collections.singletonMap("keyword", Collections.singletonMap("type", "keyword"))); + user_nested_mapping.put("custom_attribute_names", custom_attribute_mapping); + Map name_mapping = new HashMap<>(); + name_mapping.put("type", "text"); + Map name_fields_mapping = new HashMap<>(); + name_fields_mapping.put("type", "keyword"); + name_fields_mapping.put("ignore_above", 256); + name_mapping.put("fields", Collections.singletonMap("keyword", name_fields_mapping)); + user_nested_mapping.put("name", name_mapping); + Map roles_mapping = new HashMap<>(); + roles_mapping.put("type", "text"); + roles_mapping.put("fields", Collections.singletonMap("keyword", Collections.singletonMap("type", "keyword"))); + user_nested_mapping.put("roles", roles_mapping); + mappings.put(CommonName.USER_FIELD, user_mapping); + return mappings; + } + + public void testCorrectMapping() throws IOException { + Map mappings = createMapping(); + + IndexMetadata indexMetadata1 = new IndexMetadata.Builder(customIndexName) + .settings( + Settings + .builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + ) + .putMapping(new MappingMetadata("type1", Collections.singletonMap(CommonName.PROPERTIES, mappings))) + .build(); + when(clusterService.state()) + .thenReturn(ClusterState.builder(clusterName).metadata(Metadata.builder().put(indexMetadata1, true).build()).build()); + + assertTrue(adIndices.isValidResultIndexMapping(customIndexName)); + } + + /** + * Test that the mapping returned by get mapping request returns the same mapping + * but with different order + * @throws IOException when MappingMetadata constructor throws errors + */ + public void testCorrectReordered() throws IOException { + Map mappings = createMapping(); + + Map feature_mapping = new HashMap<>(); + feature_mapping.put("type", "nested"); + Map feature_nested_mapping = new HashMap<>(); + feature_mapping.put(CommonName.PROPERTIES, feature_nested_mapping); + // feature_id comes before data compared with what createMapping returned + feature_nested_mapping.put("feature_id", Collections.singletonMap("type", "keyword")); + feature_nested_mapping.put("data", Collections.singletonMap("type", "double")); + mappings.put(CommonName.FEATURE_DATA_FIELD, feature_mapping); + + IndexMetadata indexMetadata1 = new IndexMetadata.Builder(customIndexName) + .settings( + Settings + .builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + 
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + ) + .putMapping(new MappingMetadata("type1", Collections.singletonMap(CommonName.PROPERTIES, mappings))) + .build(); + when(clusterService.state()) + .thenReturn(ClusterState.builder(clusterName).metadata(Metadata.builder().put(indexMetadata1, true).build()).build()); + + assertTrue(adIndices.isValidResultIndexMapping(customIndexName)); + } + + /** + * Test that the mapping returned by get mapping request returns a super set + * of result index mapping + * @throws IOException when MappingMetadata constructor throws errors + */ + public void testSuperset() throws IOException { + Map mappings = createMapping(); + + Map feature_mapping = new HashMap<>(); + feature_mapping.put("type", "nested"); + Map feature_nested_mapping = new HashMap<>(); + feature_mapping.put(CommonName.PROPERTIES, feature_nested_mapping); + feature_nested_mapping.put("feature_id", Collections.singletonMap("type", "keyword")); + feature_nested_mapping.put("data", Collections.singletonMap("type", "double")); + mappings.put("a", feature_mapping); + + IndexMetadata indexMetadata1 = new IndexMetadata.Builder(customIndexName) + .settings( + Settings + .builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + ) + .putMapping(new MappingMetadata("type1", Collections.singletonMap(CommonName.PROPERTIES, mappings))) + .build(); + when(clusterService.state()) + .thenReturn(ClusterState.builder(clusterName).metadata(Metadata.builder().put(indexMetadata1, true).build()).build()); + + assertTrue(adIndices.isValidResultIndexMapping(customIndexName)); + } + + public void testInCorrectMapping() throws IOException { + Map mappings = new HashMap<>(); + + Map past_mapping = new HashMap<>(); + past_mapping.put("type", "nested"); + Map past_nested_mapping = new HashMap<>(); + past_mapping.put(CommonName.PROPERTIES, past_nested_mapping); + past_nested_mapping.put("data", Collections.singletonMap("type", "double")); + past_nested_mapping.put("feature_id", Collections.singletonMap("type", "keyword")); + mappings.put(AnomalyResult.PAST_VALUES_FIELD, past_mapping); + + Map attribution_mapping = new HashMap<>(); + past_mapping.put("type", "nested"); + Map attribution_nested_mapping = new HashMap<>(); + attribution_mapping.put(CommonName.PROPERTIES, attribution_nested_mapping); + attribution_nested_mapping.put("data", Collections.singletonMap("type", "double")); + attribution_nested_mapping.put("feature_id", Collections.singletonMap("type", "keyword")); + mappings.put(AnomalyResult.RELEVANT_ATTRIBUTION_FIELD, attribution_mapping); + + IndexMetadata indexMetadata1 = new IndexMetadata.Builder(customIndexName) + .settings( + Settings + .builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + ) + .putMapping(new MappingMetadata("type1", Collections.singletonMap(CommonName.PROPERTIES, mappings))) + .build(); + when(clusterService.state()) + .thenReturn(ClusterState.builder(clusterName).metadata(Metadata.builder().put(indexMetadata1, true).build()).build()); + + assertTrue(!adIndices.isValidResultIndexMapping(customIndexName)); + } + +} diff --git a/src/test/java/org/opensearch/ad/indices/InitAnomalyDetectionIndicesTests.java-e b/src/test/java/org/opensearch/ad/indices/InitAnomalyDetectionIndicesTests.java-e new file mode 100644 index 000000000..9e3c169c2 --- /dev/null 
+++ b/src/test/java/org/opensearch/ad/indices/InitAnomalyDetectionIndicesTests.java-e @@ -0,0 +1,229 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.indices; + +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; + +import org.mockito.ArgumentCaptor; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.indices.alias.Alias; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +public class InitAnomalyDetectionIndicesTests extends AbstractTimeSeriesTest { + Client client; + ClusterService clusterService; + ThreadPool threadPool; + Settings settings; + DiscoveryNodeFilterer nodeFilter; + ADIndexManagement adIndices; + ClusterName clusterName; + ClusterState clusterState; + IndicesAdminClient indicesClient; + int numberOfHotNodes; + + @Override + public void setUp() throws Exception { + super.setUp(); + + client = mock(Client.class); + indicesClient = mock(IndicesAdminClient.class); + AdminClient adminClient = mock(AdminClient.class); + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesClient); + + clusterService = mock(ClusterService.class); + threadPool = mock(ThreadPool.class); + + numberOfHotNodes = 4; + nodeFilter = mock(DiscoveryNodeFilterer.class); + when(nodeFilter.getNumberOfEligibleDataNodes()).thenReturn(numberOfHotNodes); + + Settings settings = Settings.EMPTY; + ClusterSettings clusterSettings = new ClusterSettings( + settings, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD, + AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, + AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD, + AnomalyDetectorSettings.AD_MAX_PRIMARY_SHARDS + ) + ) + ) + ); + + clusterName = new ClusterName("test"); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build(); + 
when(clusterService.state()).thenReturn(clusterState); + + adIndices = new ADIndexManagement( + client, + clusterService, + threadPool, + settings, + nodeFilter, + TimeSeriesSettings.MAX_UPDATE_RETRY_TIMES + ); + } + + @SuppressWarnings("unchecked") + private void fixedPrimaryShardsIndexCreationTemplate(String index) throws IOException { + doAnswer(invocation -> { + CreateIndexRequest request = invocation.getArgument(0); + assertEquals(index, request.index()); + + ActionListener listener = (ActionListener) invocation.getArgument(1); + + listener.onResponse(new CreateIndexResponse(true, true, index)); + return null; + }).when(indicesClient).create(any(), any()); + + ActionListener listener = mock(ActionListener.class); + if (index.equals(CommonName.CONFIG_INDEX)) { + adIndices.initConfigIndexIfAbsent(listener); + } else { + adIndices.initStateIndex(listener); + } + + ArgumentCaptor captor = ArgumentCaptor.forClass(CreateIndexResponse.class); + verify(listener).onResponse(captor.capture()); + CreateIndexResponse result = captor.getValue(); + assertEquals(index, result.index()); + } + + @SuppressWarnings("unchecked") + private void fixedPrimaryShardsIndexNoCreationTemplate(String index, String... alias) throws IOException { + clusterState = mock(ClusterState.class); + when(clusterService.state()).thenReturn(clusterState); + + // RoutingTable.Builder rb = RoutingTable.builder(); + // rb.addAsNew(indexMeta(index, 1L)); + // when(clusterState.metadata()).thenReturn(rb.build()); + + Metadata.Builder mb = Metadata.builder(); + // mb.put(indexMeta(".opendistro-anomaly-results-history-2020.06.24-000003", 1L, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS), true); + mb.put(indexMeta(index, 1L, alias), true); + when(clusterState.metadata()).thenReturn(mb.build()); + + ActionListener listener = mock(ActionListener.class); + if (index.equals(CommonName.CONFIG_INDEX)) { + adIndices.initConfigIndexIfAbsent(listener); + } else { + adIndices.initDefaultResultIndexIfAbsent(listener); + } + + verify(indicesClient, never()).create(any(), any()); + } + + @SuppressWarnings("unchecked") + private void adaptivePrimaryShardsIndexCreationTemplate(String index) throws IOException { + + doAnswer(invocation -> { + CreateIndexRequest request = invocation.getArgument(0); + if (index.equals(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)) { + assertTrue(request.aliases().contains(new Alias(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS))); + } else { + assertEquals(index, request.index()); + } + + Settings settings = request.settings(); + if (index.equals(CommonName.JOB_INDEX)) { + assertThat(settings.get("index.number_of_shards"), equalTo(Integer.toString(1))); + } else { + assertThat(settings.get("index.number_of_shards"), equalTo(Integer.toString(numberOfHotNodes))); + } + + ActionListener listener = (ActionListener) invocation.getArgument(1); + + listener.onResponse(new CreateIndexResponse(true, true, index)); + return null; + }).when(indicesClient).create(any(), any()); + + ActionListener listener = mock(ActionListener.class); + if (index.equals(CommonName.CONFIG_INDEX)) { + adIndices.initConfigIndexIfAbsent(listener); + } else if (index.equals(ADCommonName.DETECTION_STATE_INDEX)) { + adIndices.initStateIndex(listener); + } else if (index.equals(ADCommonName.CHECKPOINT_INDEX_NAME)) { + adIndices.initCheckpointIndex(listener); + } else if (index.equals(CommonName.JOB_INDEX)) { + adIndices.initJobIndex(listener); + } else { + adIndices.initDefaultResultIndexIfAbsent(listener); + } + + ArgumentCaptor captor = 
ArgumentCaptor.forClass(CreateIndexResponse.class); + verify(listener).onResponse(captor.capture()); + CreateIndexResponse result = captor.getValue(); + assertEquals(index, result.index()); + } + + public void testNotCreateDetector() throws IOException { + fixedPrimaryShardsIndexNoCreationTemplate(CommonName.CONFIG_INDEX); + } + + public void testNotCreateResult() throws IOException { + fixedPrimaryShardsIndexNoCreationTemplate(CommonName.CONFIG_INDEX); + } + + public void testCreateDetector() throws IOException { + fixedPrimaryShardsIndexCreationTemplate(CommonName.CONFIG_INDEX); + } + + public void testCreateState() throws IOException { + fixedPrimaryShardsIndexCreationTemplate(ADCommonName.DETECTION_STATE_INDEX); + } + + public void testCreateJob() throws IOException { + adaptivePrimaryShardsIndexCreationTemplate(CommonName.JOB_INDEX); + } + + public void testCreateResult() throws IOException { + adaptivePrimaryShardsIndexCreationTemplate(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); + } + + public void testCreateCheckpoint() throws IOException { + adaptivePrimaryShardsIndexCreationTemplate(ADCommonName.CHECKPOINT_INDEX_NAME); + } +} diff --git a/src/test/java/org/opensearch/ad/indices/RolloverTests.java-e b/src/test/java/org/opensearch/ad/indices/RolloverTests.java-e new file mode 100644 index 000000000..abb853aef --- /dev/null +++ b/src/test/java/org/opensearch/ad/indices/RolloverTests.java-e @@ -0,0 +1,249 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.indices; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.cluster.state.ClusterStateRequest; +import org.opensearch.action.admin.cluster.state.ClusterStateResponse; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.rollover.Condition; +import org.opensearch.action.admin.indices.rollover.MaxDocsCondition; +import org.opensearch.action.admin.indices.rollover.RolloverRequest; +import org.opensearch.action.admin.indices.rollover.RolloverResponse; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import 
org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +public class RolloverTests extends AbstractTimeSeriesTest { + private ADIndexManagement adIndices; + private IndicesAdminClient indicesClient; + private ClusterAdminClient clusterAdminClient; + private ClusterName clusterName; + private ClusterState clusterState; + private ClusterService clusterService; + private long defaultMaxDocs; + private int numberOfNodes; + + @Override + public void setUp() throws Exception { + super.setUp(); + Client client = mock(Client.class); + indicesClient = mock(IndicesAdminClient.class); + AdminClient adminClient = mock(AdminClient.class); + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD, + AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, + AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD, + AnomalyDetectorSettings.AD_MAX_PRIMARY_SHARDS + ) + ) + ) + ); + + clusterName = new ClusterName("test"); + + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + ThreadPool threadPool = mock(ThreadPool.class); + Settings settings = Settings.EMPTY; + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesClient); + + DiscoveryNodeFilterer nodeFilter = mock(DiscoveryNodeFilterer.class); + numberOfNodes = 2; + when(nodeFilter.getNumberOfEligibleDataNodes()).thenReturn(numberOfNodes); + + adIndices = new ADIndexManagement( + client, + clusterService, + threadPool, + settings, + nodeFilter, + TimeSeriesSettings.MAX_UPDATE_RETRY_TIMES + ); + + clusterAdminClient = mock(ClusterAdminClient.class); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + + doAnswer(invocation -> { + ClusterStateRequest clusterStateRequest = invocation.getArgument(0); + assertEquals(ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN, clusterStateRequest.indices()[0]); + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArgument(1); + listener.onResponse(new ClusterStateResponse(clusterName, clusterState, true)); + return null; + }).when(clusterAdminClient).state(any(), any()); + + defaultMaxDocs = AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD.getDefault(Settings.EMPTY); + } + + private void assertRolloverRequest(RolloverRequest request) { + assertEquals(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, request.indices()[0]); + + Map> conditions = request.getConditions(); + assertEquals(1, conditions.size()); + assertEquals(new MaxDocsCondition(defaultMaxDocs * numberOfNodes), conditions.get(MaxDocsCondition.NAME)); + + CreateIndexRequest createIndexRequest = request.getCreateIndexRequest(); + assertEquals(ADIndexManagement.AD_RESULT_HISTORY_INDEX_PATTERN, createIndexRequest.index()); + assertTrue(createIndexRequest.mappings().contains("data_start_time")); + } + + public void testNotRolledOver() { + doAnswer(invocation -> { + RolloverRequest request = invocation.getArgument(0); + assertRolloverRequest(request); + + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArgument(1); + + listener.onResponse(new RolloverResponse(null, null, Collections.emptyMap(), request.isDryRun(), false, true, true)); + return null; + }).when(indicesClient).rolloverIndex(any(), any()); + + Metadata.Builder metaBuilder = Metadata + 
.builder() + .put(indexMeta(".opendistro-anomaly-results-history-2020.06.24-000003", 1L, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS), true); + clusterState = ClusterState.builder(clusterName).metadata(metaBuilder.build()).build(); + when(clusterService.state()).thenReturn(clusterState); + + adIndices.rolloverAndDeleteHistoryIndex(); + verify(clusterAdminClient, never()).state(any(), any()); + verify(indicesClient, times(1)).rolloverIndex(any(), any()); + } + + private void setUpRolloverSuccess() { + doAnswer(invocation -> { + RolloverRequest request = invocation.getArgument(0); + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArgument(1); + + assertEquals(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, request.indices()[0]); + + Map> conditions = request.getConditions(); + assertEquals(1, conditions.size()); + assertEquals(new MaxDocsCondition(defaultMaxDocs * numberOfNodes), conditions.get(MaxDocsCondition.NAME)); + + CreateIndexRequest createIndexRequest = request.getCreateIndexRequest(); + assertEquals(ADIndexManagement.AD_RESULT_HISTORY_INDEX_PATTERN, createIndexRequest.index()); + assertTrue(createIndexRequest.mappings().contains("data_start_time")); + listener.onResponse(new RolloverResponse(null, null, Collections.emptyMap(), request.isDryRun(), true, true, true)); + return null; + }).when(indicesClient).rolloverIndex(any(), any()); + } + + public void testRolledOverButNotDeleted() { + setUpRolloverSuccess(); + + Metadata.Builder metaBuilder = Metadata + .builder() + .put(indexMeta(".opendistro-anomaly-results-history-2020.06.24-000003", 1L, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS), true) + .put( + indexMeta( + ".opendistro-anomaly-results-history-2020.06.24-000004", + Instant.now().toEpochMilli(), + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS + ), + true + ); + clusterState = ClusterState.builder(clusterName).metadata(metaBuilder.build()).build(); + when(clusterService.state()).thenReturn(clusterState); + + adIndices.rolloverAndDeleteHistoryIndex(); + verify(clusterAdminClient, times(1)).state(any(), any()); + verify(indicesClient, times(1)).rolloverIndex(any(), any()); + verify(indicesClient, never()).delete(any(), any()); + } + + private void setUpTriggerDelete() { + Metadata.Builder metaBuilder = Metadata + .builder() + .put(indexMeta(".opendistro-anomaly-results-history-2020.06.24-000002", 1L, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS), true) + .put(indexMeta(".opendistro-anomaly-results-history-2020.06.24-000003", 2L, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS), true) + .put( + indexMeta( + ".opendistro-anomaly-results-history-2020.06.24-000004", + Instant.now().toEpochMilli(), + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS + ), + true + ); + clusterState = ClusterState.builder(clusterName).metadata(metaBuilder.build()).build(); + when(clusterService.state()).thenReturn(clusterState); + } + + public void testRolledOverDeleted() { + setUpRolloverSuccess(); + setUpTriggerDelete(); + + adIndices.rolloverAndDeleteHistoryIndex(); + verify(clusterAdminClient, times(1)).state(any(), any()); + verify(indicesClient, times(1)).rolloverIndex(any(), any()); + verify(indicesClient, times(1)).delete(any(), any()); + } + + public void testRetryingDelete() { + setUpRolloverSuccess(); + setUpTriggerDelete(); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArgument(1); + + // group delete not acked, trigger retry. But retry also failed. 
+ listener.onResponse(new AcknowledgedResponse(false)); + + return null; + }).when(indicesClient).delete(any(), any()); + + adIndices.rolloverAndDeleteHistoryIndex(); + verify(clusterAdminClient, times(1)).state(any(), any()); + verify(indicesClient, times(1)).rolloverIndex(any(), any()); + // 1 group delete, 1 separate retry for each index to delete + verify(indicesClient, times(2)).delete(any(), any()); + } +} diff --git a/src/test/java/org/opensearch/ad/indices/UpdateMappingTests.java-e b/src/test/java/org/opensearch/ad/indices/UpdateMappingTests.java-e new file mode 100644 index 000000000..f53393014 --- /dev/null +++ b/src/test/java/org/opensearch/ad/indices/UpdateMappingTests.java-e @@ -0,0 +1,318 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.indices; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; + +import org.junit.BeforeClass; +import org.mockito.ArgumentCaptor; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.indices.alias.get.GetAliasesRequest; +import org.opensearch.action.admin.indices.alias.get.GetAliasesResponse; +import org.opensearch.action.admin.indices.settings.get.GetSettingsAction; +import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; +import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.AliasMetadata; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.routing.RoutingTable; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +public class UpdateMappingTests extends AbstractTimeSeriesTest { + private static String resultIndexName; + + private ADIndexManagement adIndices; + private ClusterService clusterService; + private int numberOfNodes; + private AdminClient adminClient; + private ClusterState clusterState; + private IndicesAdminClient 
indicesAdminClient; + private Client client; + private Settings settings; + private DiscoveryNodeFilterer nodeFilter; + + @BeforeClass + public static void setUpBeforeClass() { + resultIndexName = ".opendistro-anomaly-results-history-2020.06.24-000003"; + } + + @SuppressWarnings("unchecked") + @Override + public void setUp() throws Exception { + super.setUp(); + + client = mock(Client.class); + adminClient = mock(AdminClient.class); + when(client.admin()).thenReturn(adminClient); + indicesAdminClient = mock(IndicesAdminClient.class); + when(adminClient.indices()).thenReturn(indicesAdminClient); + + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD, + AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, + AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD, + AnomalyDetectorSettings.AD_MAX_PRIMARY_SHARDS + ) + ) + ) + ); + + clusterState = mock(ClusterState.class); + + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.state()).thenReturn(clusterState); + + Map openMap = new HashMap<>(); + Metadata metadata = Metadata.builder().indices(openMap).build(); + when(clusterState.getMetadata()).thenReturn(metadata); + when(clusterState.metadata()).thenReturn(metadata); + + RoutingTable routingTable = mock(RoutingTable.class); + when(clusterState.getRoutingTable()).thenReturn(routingTable); + when(routingTable.hasIndex(anyString())).thenReturn(true); + + settings = Settings.EMPTY; + nodeFilter = mock(DiscoveryNodeFilterer.class); + numberOfNodes = 2; + when(nodeFilter.getNumberOfEligibleDataNodes()).thenReturn(numberOfNodes); + adIndices = new ADIndexManagement( + client, + clusterService, + threadPool, + settings, + nodeFilter, + TimeSeriesSettings.MAX_UPDATE_RETRY_TIMES + ); + } + + public void testNoIndexToUpdate() { + adIndices.update(); + verify(indicesAdminClient, never()).putMapping(any(), any()); + // for an index, we may check doesAliasExists/doesIndexExists + verify(clusterService, times(5)).state(); + adIndices.update(); + // we will not trigger new check since we have checked all indices before + verify(clusterService, times(5)).state(); + } + + @SuppressWarnings({ "serial", "unchecked" }) + public void testUpdateMapping() throws IOException { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArgument(1); + + Map> builder = new HashMap<>(); + List aliasMetadata = new ArrayList<>(); + aliasMetadata.add(AliasMetadata.builder(ADIndex.RESULT.name()).build()); + builder.put(resultIndexName, aliasMetadata); + + listener.onResponse(new GetAliasesResponse(builder)); + return null; + }).when(indicesAdminClient).getAliases(any(GetAliasesRequest.class), any()); + + IndexMetadata indexMetadata = IndexMetadata + .builder(resultIndexName) + .putAlias(AliasMetadata.builder(ADIndex.RESULT.getIndexName())) + .settings(settings(Version.CURRENT)) + .numberOfShards(1) + .numberOfReplicas(0) + .putMapping(new MappingMetadata("type", new HashMap() { + { + put(ADIndexManagement.META, new HashMap() { + { + // version 1 will cause update + put(CommonName.SCHEMA_VERSION_FIELD, 1); + } + }); + } + })) + .build(); + Map openMapBuilder = new HashMap<>(); + openMapBuilder.put(resultIndexName, indexMetadata); + Metadata metadata = Metadata.builder().indices(openMapBuilder).build(); + 
when(clusterState.getMetadata()).thenReturn(metadata); + when(clusterState.metadata()).thenReturn(metadata); + adIndices.update(); + verify(indicesAdminClient, times(1)).putMapping(any(), any()); + } + + // since SETTING_AUTO_EXPAND_REPLICAS is set, we won't update + @SuppressWarnings("unchecked") + public void testJobSettingNoUpdate() { + Map indexToSettings = new HashMap<>(); + Settings jobSettings = Settings + .builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 2) + .put(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS, "1-all") + .build(); + indexToSettings.put(ADIndex.JOB.getIndexName(), jobSettings); + GetSettingsResponse getSettingsResponse = new GetSettingsResponse(indexToSettings, new HashMap<>()); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArgument(2); + + listener.onResponse(getSettingsResponse); + return null; + }).when(client).execute(any(), any(), any()); + adIndices.update(); + verify(indicesAdminClient, never()).updateSettings(any(), any()); + } + + @SuppressWarnings("unchecked") + private void setUpSuccessfulGetJobSetting() { + Map indexToSettings = new HashMap<>(); + Settings jobSettings = Settings + .builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 2) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .build(); + indexToSettings.put(ADIndex.JOB.getIndexName(), jobSettings); + GetSettingsResponse getSettingsResponse = new GetSettingsResponse(indexToSettings, new HashMap<>()); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArgument(2); + + listener.onResponse(getSettingsResponse); + return null; + }).when(client).execute(any(), any(), any()); + } + + // since SETTING_AUTO_EXPAND_REPLICAS is set, we won't update + @SuppressWarnings("unchecked") + public void testJobSettingUpdate() { + setUpSuccessfulGetJobSetting(); + ArgumentCaptor createIndexRequestCaptor = ArgumentCaptor.forClass(UpdateSettingsRequest.class); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(new AcknowledgedResponse(true)); + return null; + }).when(indicesAdminClient).updateSettings(createIndexRequestCaptor.capture(), any()); + adIndices.update(); + verify(client, times(1)).execute(eq(GetSettingsAction.INSTANCE), any(), any()); + verify(indicesAdminClient, times(1)).updateSettings(any(), any()); + UpdateSettingsRequest request = createIndexRequestCaptor.getValue(); + assertEquals("1-10", request.settings().get(IndexMetadata.SETTING_AUTO_EXPAND_REPLICAS)); + + adIndices.update(); + // won't have to do it again since we succeeded last time + verify(client, times(1)).execute(eq(GetSettingsAction.INSTANCE), any(), any()); + verify(indicesAdminClient, times(1)).updateSettings(any(), any()); + } + + // since SETTING_NUMBER_OF_SHARDS is not there, we skip updating + @SuppressWarnings("unchecked") + public void testMissingPrimaryJobShards() { + Map indexToSettings = new HashMap<>(); + Settings jobSettings = Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT).build(); + indexToSettings.put(ADIndex.JOB.getIndexName(), jobSettings); + GetSettingsResponse getSettingsResponse = new GetSettingsResponse(indexToSettings, new HashMap<>()); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArgument(2); + + listener.onResponse(getSettingsResponse); + return null; + 
}).when(client).execute(any(), any(), any()); + adIndices.update(); + verify(indicesAdminClient, never()).updateSettings(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testJobIndexNotFound() { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArgument(2); + + listener.onFailure(new IndexNotFoundException(ADIndex.JOB.getIndexName())); + return null; + }).when(client).execute(any(), any(), any()); + + adIndices.update(); + verify(client, times(1)).execute(eq(GetSettingsAction.INSTANCE), any(), any()); + verify(indicesAdminClient, never()).updateSettings(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testFailtoUpdateJobSetting() { + setUpSuccessfulGetJobSetting(); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArgument(2); + + listener.onFailure(new RuntimeException(ADIndex.JOB.getIndexName())); + return null; + }).when(indicesAdminClient).updateSettings(any(), any()); + + adIndices.update(); + verify(client, times(1)).execute(eq(GetSettingsAction.INSTANCE), any(), any()); + verify(indicesAdminClient, times(1)).updateSettings(any(), any()); + + // will have to do it again since last time we fail + adIndices.update(); + verify(client, times(2)).execute(eq(GetSettingsAction.INSTANCE), any(), any()); + verify(indicesAdminClient, times(2)).updateSettings(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testTooManyUpdate() throws IOException { + setUpSuccessfulGetJobSetting(); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArgument(2); + + listener.onFailure(new RuntimeException(ADIndex.JOB.getIndexName())); + return null; + }).when(indicesAdminClient).updateSettings(any(), any()); + + adIndices = new ADIndexManagement(client, clusterService, threadPool, settings, nodeFilter, 1); + + adIndices.update(); + adIndices.update(); + + // even though we updated two times, since it passed the max retry limit (1), we won't retry + verify(client, times(1)).execute(eq(GetSettingsAction.INSTANCE), any(), any()); + } +} diff --git a/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java b/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java index f8da88458..e7d74f28a 100644 --- a/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java +++ b/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java @@ -33,7 +33,6 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.MemoryTracker; import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.feature.FeatureManager; @@ -54,6 +53,7 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.dataprocessor.Imputer; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; @@ -169,7 +169,7 @@ public void setUp() throws Exception { AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, AnomalyDetectorSettings.HOURLY_MAINTENANCE, threadPool, - AnomalyDetectorPlugin.AD_THREAD_POOL_NAME + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ); checkpointWriteQueue = mock(CheckpointWriteWorker.class); diff --git a/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java-e 
b/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java-e new file mode 100644 index 000000000..e7d74f28a --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java-e @@ -0,0 +1,251 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; + +import java.time.Clock; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodeRole; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.ClusterServiceUtils; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.dataprocessor.Imputer; +import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; + +import com.google.common.collect.ImmutableList; + +public class AbstractCosineDataTest extends AbstractTimeSeriesTest { + int numMinSamples; + String modelId; + String entityName; + String detectorId; + ModelState modelState; + Clock clock; + float priority; + EntityColdStarter entityColdStarter; + NodeStateManager stateManager; + SearchFeatureDao searchFeatureDao; + Imputer imputer; + CheckpointDao checkpoint; + FeatureManager featureManager; + Settings settings; + ThreadPool threadPool; + AtomicBoolean released; + Runnable releaseSemaphore; + ActionListener listener; + CountDownLatch inProgressLatch; + CheckpointWriteWorker checkpointWriteQueue; + Entity entity; + AnomalyDetector detector; + long 
rcfSeed; + ModelManager modelManager; + ClientUtil clientUtil; + ClusterService clusterService; + ClusterSettings clusterSettings; + DiscoveryNode discoveryNode; + Set> nodestateSetting; + + @SuppressWarnings("unchecked") + @Override + public void setUp() throws Exception { + super.setUp(); + numMinSamples = AnomalyDetectorSettings.NUM_MIN_SAMPLES; + + clock = mock(Clock.class); + when(clock.instant()).thenReturn(Instant.now()); + + threadPool = mock(ThreadPool.class); + setUpADThreadPool(threadPool); + + settings = Settings.EMPTY; + + Client client = mock(Client.class); + clientUtil = mock(ClientUtil.class); + + detector = TestHelpers.AnomalyDetectorBuilder + .newInstance() + .setDetectionInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)) + .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) + .build(); + when(clock.millis()).thenReturn(1602401500000L); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, CommonName.CONFIG_INDEX)); + + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + + nodestateSetting = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + nodestateSetting.add(MAX_RETRY_FOR_UNRESPONSIVE_NODE); + nodestateSetting.add(BACKOFF_MINUTES); + nodestateSetting.add(CHECKPOINT_SAVING_FREQ); + clusterSettings = new ClusterSettings(Settings.EMPTY, nodestateSetting); + + discoveryNode = new DiscoveryNode( + "node1", + OpenSearchTestCase.buildNewFakeTransportAddress(), + Collections.emptyMap(), + DiscoveryNodeRole.BUILT_IN_ROLES, + Version.CURRENT + ); + + clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); + + stateManager = new NodeStateManager( + client, + xContentRegistry(), + settings, + clientUtil, + clock, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + clusterService + ); + + imputer = new LinearUniformImputer(true); + + searchFeatureDao = mock(SearchFeatureDao.class); + checkpoint = mock(CheckpointDao.class); + + featureManager = new FeatureManager( + searchFeatureDao, + imputer, + clock, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, + AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, + AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, + AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME + ); + + checkpointWriteQueue = mock(CheckpointWriteWorker.class); + + rcfSeed = 2051L; + entityColdStarter = new EntityColdStarter( + clock, + threadPool, + stateManager, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.TIME_DECAY, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + imputer, + searchFeatureDao, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + featureManager, + settings, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + rcfSeed, + AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + ); + + detectorId = "123"; + modelId = "123_entity_abc"; + entityName = "abc"; + priority = 0.3f; + entity = Entity.createSingleAttributeEntity("field", entityName); + + released = new AtomicBoolean(); + + inProgressLatch 
= new CountDownLatch(1); + releaseSemaphore = () -> { + released.set(true); + inProgressLatch.countDown(); + }; + listener = ActionListener.wrap(releaseSemaphore); + + modelManager = new ModelManager( + mock(CheckpointDao.class), + mock(Clock.class), + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.TIME_DECAY, + AnomalyDetectorSettings.NUM_MIN_SAMPLES, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.MIN_PREVIEW_SIZE, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, + entityColdStarter, + mock(FeatureManager.class), + mock(MemoryTracker.class), + settings, + clusterService + ); + } + + protected void checkSemaphoreRelease() throws InterruptedException { + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + assertTrue(released.get()); + } + + public int searchInsert(long[] timestamps, long target) { + int pivot, left = 0, right = timestamps.length - 1; + while (left <= right) { + pivot = left + (right - left) / 2; + if (timestamps[pivot] == target) + return pivot; + if (target < timestamps[pivot]) + right = pivot - 1; + else + left = pivot + 1; + } + return left; + } +} diff --git a/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java b/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java index bd3ca3d5b..8c3e6c472 100644 --- a/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java +++ b/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java @@ -99,9 +99,9 @@ import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.util.ClientUtil; import org.opensearch.client.Client; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.engine.VersionConflictEngineException; -import org.opensearch.index.shard.ShardId; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.timeseries.constant.CommonName; diff --git a/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java-e b/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java-e new file mode 100644 index 000000000..f3d68e9ea --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java-e @@ -0,0 +1,1101 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ml; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.ad.ml.CheckpointDao.FIELD_MODELV2; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileReader; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.Charset; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.Month; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.Queue; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.pool2.BasePooledObjectFactory; +import org.apache.commons.pool2.PooledObject; +import org.apache.commons.pool2.impl.DefaultPooledObject; +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.junit.Before; +import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.get.MultiGetAction; +import org.opensearch.action.get.MultiGetItemResponse; +import org.opensearch.action.get.MultiGetRequest; +import org.opensearch.action.get.MultiGetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.replication.ReplicationResponse; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import 
org.opensearch.ad.util.ClientUtil; +import org.opensearch.client.Client; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.constant.CommonName; + +import test.org.opensearch.ad.util.JsonDeserializer; +import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; + +import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.config.Precision; +import com.amazon.randomcutforest.parkservices.AnomalyDescriptor; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; +import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper; +import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState; +import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV3StateConverter; +import com.amazon.randomcutforest.state.RandomCutForestMapper; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; + +import io.protostuff.LinkedBuffer; +import io.protostuff.Schema; +import io.protostuff.runtime.RuntimeSchema; + +public class CheckpointDaoTests extends OpenSearchTestCase { + private static final Logger logger = LogManager.getLogger(CheckpointDaoTests.class); + + private CheckpointDao checkpointDao; + + // dependencies + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private Client client; + + @Mock + private ClientUtil clientUtil; + + @Mock + private GetResponse getResponse; + + @Mock + private Clock clock; + + @Mock + private ADIndexManagement indexUtil; + + private Schema trcfSchema; + + // configuration + private String indexName; + + // test data + private String modelId; + + private Gson gson; + private Class thresholdingModelClass; + + private int maxCheckpointBytes = 1_000_000; + private GenericObjectPool serializeRCFBufferPool; + private RandomCutForestMapper mapper; + private ThresholdedRandomCutForestMapper trcfMapper; + private V1JsonToV3StateConverter converter; + double anomalyRate; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + + indexName = "testIndexName"; + + // gson = PowerMockito.mock(Gson.class); + gson = new GsonBuilder().serializeSpecialFloatingPointValues().create(); + + thresholdingModelClass = HybridThresholdingModel.class; + + when(clock.instant()).thenReturn(Instant.now()); + + mapper = new RandomCutForestMapper(); + mapper.setSaveExecutorContextEnabled(true); + + trcfMapper = new ThresholdedRandomCutForestMapper(); + trcfSchema = AccessController + .doPrivileged( + (PrivilegedAction>) () -> RuntimeSchema + .getSchema(ThresholdedRandomCutForestState.class) + ); + + converter = new V1JsonToV3StateConverter(); + + serializeRCFBufferPool = spy(AccessController.doPrivileged(new PrivilegedAction>() { + @Override + public GenericObjectPool run() { + return new GenericObjectPool<>(new BasePooledObjectFactory() { + @Override + public LinkedBuffer create() throws Exception { + return LinkedBuffer.allocate(AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES); + } + + @Override + public PooledObject wrap(LinkedBuffer obj) { + return new DefaultPooledObject<>(obj); + } + }); + } + })); + serializeRCFBufferPool.setMaxTotal(AnomalyDetectorSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); + serializeRCFBufferPool.setMaxIdle(AnomalyDetectorSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); + serializeRCFBufferPool.setMinIdle(0); + 
serializeRCFBufferPool.setBlockWhenExhausted(false); + serializeRCFBufferPool.setTimeBetweenEvictionRuns(AnomalyDetectorSettings.HOURLY_MAINTENANCE); + + anomalyRate = 0.005; + checkpointDao = new CheckpointDao( + client, + clientUtil, + indexName, + gson, + mapper, + converter, + trcfMapper, + trcfSchema, + thresholdingModelClass, + indexUtil, + maxCheckpointBytes, + serializeRCFBufferPool, + AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES, + anomalyRate + ); + + when(indexUtil.doesCheckpointIndexExist()).thenReturn(true); + + modelId = "testModelId"; + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + serializeRCFBufferPool.close(); + } + + private ThresholdedRandomCutForest createTRCF() { + int dimensions = 4; + int numberOfTrees = 1; + int sampleSize = 256; + int dataSize = 10 * sampleSize; + Random random = new Random(); + long seed = random.nextLong(); + double[][] data = MLUtil.generateShingledData(dataSize, dimensions, 2); + ThresholdedRandomCutForest forest = ThresholdedRandomCutForest + .builder() + .compact(true) + .dimensions(dimensions) + .numberOfTrees(numberOfTrees) + .sampleSize(sampleSize) + .precision(Precision.FLOAT_32) + .randomSeed(seed) + .boundingBoxCacheFraction(0) + .build(); + for (double[] point : data) { + forest.process(point, 0); + } + return forest; + } + + @SuppressWarnings("unchecked") + private void verifyPutModelCheckpointAsync() { + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(UpdateRequest.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(clientUtil).asyncRequest(requestCaptor.capture(), any(BiConsumer.class), any(ActionListener.class)); + + ActionListener listener = mock(ActionListener.class); + + checkpointDao.putTRCFCheckpoint(modelId, createTRCF(), listener); + + UpdateRequest updateRequest = requestCaptor.getValue(); + assertEquals(indexName, updateRequest.index()); + assertEquals(modelId, updateRequest.id()); + IndexRequest indexRequest = updateRequest.doc(); + Set expectedSourceKeys = new HashSet(Arrays.asList(FIELD_MODELV2, CommonName.TIMESTAMP)); + assertEquals(expectedSourceKeys, indexRequest.sourceAsMap().keySet()); + assertTrue(!((String) (indexRequest.sourceAsMap().get(FIELD_MODELV2))).isEmpty()); + assertNotNull(indexRequest.sourceAsMap().get(CommonName.TIMESTAMP)); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Void.class); + verify(listener).onResponse(responseCaptor.capture()); + Void response = responseCaptor.getValue(); + assertEquals(null, response); + } + + public void test_putModelCheckpoint_callListener_whenCompleted() { + verifyPutModelCheckpointAsync(); + } + + public void test_putModelCheckpoint_callListener_no_checkpoint_index() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(new CreateIndexResponse(true, true, ADCommonName.CHECKPOINT_INDEX_NAME)); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + verifyPutModelCheckpointAsync(); + } + + public void test_putModelCheckpoint_callListener_race_condition() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new ResourceAlreadyExistsException(ADCommonName.CHECKPOINT_INDEX_NAME)); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + 
verifyPutModelCheckpointAsync(); + } + + @SuppressWarnings("unchecked") + public void test_putModelCheckpoint_callListener_unexpected_exception() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException("")); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + ActionListener listener = mock(ActionListener.class); + checkpointDao.putTRCFCheckpoint(modelId, createTRCF(), listener); + + verify(clientUtil, never()).asyncRequest(any(), any(), any()); + } + + @SuppressWarnings("unchecked") + public void test_getModelCheckpoint_returnExpectedToListener() { + // ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(GetRequest.class); + UpdateResponse updateResponse = new UpdateResponse( + new ReplicationResponse.ShardInfo(3, 2), + new ShardId(ADCommonName.CHECKPOINT_INDEX_NAME, "uuid", 2), + "1", + 7, + 17, + 2, + UPDATED + ); + AtomicReference getRequest = new AtomicReference<>(); + doAnswer(invocation -> { + ActionRequest request = invocation.getArgument(0); + if (request instanceof GetRequest) { + getRequest.set((GetRequest) request); + ActionListener listener = invocation.getArgument(2); + listener.onResponse(getResponse); + } else { + UpdateRequest updateRequest = (UpdateRequest) request; + when(getResponse.getSource()).thenReturn(updateRequest.doc().sourceAsMap()); + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + } + return null; + }).when(clientUtil).asyncRequest(any(), any(BiConsumer.class), any(ActionListener.class)); + when(getResponse.isExists()).thenReturn(true); + + ThresholdedRandomCutForest trcf = createTRCF(); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + checkpointDao.putTRCFCheckpoint(modelId, trcf, ActionListener.wrap(response -> { inProgressLatch.countDown(); }, exception -> { + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + })); + + ActionListener> listener = mock(ActionListener.class); + checkpointDao.getTRCFModel(modelId, listener); + + GetRequest capturedGetRequest = getRequest.get(); + assertEquals(indexName, capturedGetRequest.index()); + assertEquals(modelId, capturedGetRequest.id()); + ArgumentCaptor> responseCaptor = ArgumentCaptor.forClass(Optional.class); + verify(listener).onResponse(responseCaptor.capture()); + Optional result = responseCaptor.getValue(); + assertTrue(result.isPresent()); + RandomCutForest deserializedForest = result.get().getForest(); + RandomCutForest serializedForest = trcf.getForest(); + assertEquals(deserializedForest.getDimensions(), serializedForest.getDimensions()); + assertEquals(deserializedForest.getNumberOfTrees(), serializedForest.getNumberOfTrees()); + assertEquals(deserializedForest.getSampleSize(), serializedForest.getSampleSize()); + } + + @SuppressWarnings("unchecked") + public void test_getModelCheckpoint_Bwc() { + // ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(GetRequest.class); + UpdateResponse updateResponse = new UpdateResponse( + new ReplicationResponse.ShardInfo(3, 2), + new ShardId(ADCommonName.CHECKPOINT_INDEX_NAME, "uuid", 2), + "1", + 7, + 17, + 2, + UPDATED + ); + AtomicReference getRequest = new AtomicReference<>(); + doAnswer(invocation -> { + ActionRequest request = invocation.getArgument(0); + if (request instanceof GetRequest) { + getRequest.set((GetRequest) request); + ActionListener listener = invocation.getArgument(2); + 
listener.onResponse(getResponse); + } else { + UpdateRequest updateRequest = (UpdateRequest) request; + when(getResponse.getSource()).thenReturn(updateRequest.doc().sourceAsMap()); + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + } + return null; + }).when(clientUtil).asyncRequest(any(), any(BiConsumer.class), any(ActionListener.class)); + when(getResponse.isExists()).thenReturn(true); + + ThresholdedRandomCutForest trcf = createTRCF(); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + checkpointDao.putTRCFCheckpoint(modelId, trcf, ActionListener.wrap(response -> { inProgressLatch.countDown(); }, exception -> { + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + })); + + ActionListener> listener = mock(ActionListener.class); + checkpointDao.getTRCFModel(modelId, listener); + + GetRequest capturedGetRequest = getRequest.get(); + assertEquals(indexName, capturedGetRequest.index()); + assertEquals(modelId, capturedGetRequest.id()); + ArgumentCaptor> responseCaptor = ArgumentCaptor.forClass(Optional.class); + verify(listener).onResponse(responseCaptor.capture()); + Optional result = responseCaptor.getValue(); + assertTrue(result.isPresent()); + RandomCutForest deserializedForest = result.get().getForest(); + RandomCutForest serializedForest = trcf.getForest(); + assertEquals(deserializedForest.getDimensions(), serializedForest.getDimensions()); + assertEquals(deserializedForest.getNumberOfTrees(), serializedForest.getNumberOfTrees()); + assertEquals(deserializedForest.getSampleSize(), serializedForest.getSampleSize()); + } + + @SuppressWarnings("unchecked") + public void test_getModelCheckpoint_returnEmptyToListener_whenModelNotFound() { + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(GetRequest.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(getResponse); + return null; + }).when(clientUtil).asyncRequest(requestCaptor.capture(), any(BiConsumer.class), any(ActionListener.class)); + when(getResponse.isExists()).thenReturn(false); + + ActionListener> listener = mock(ActionListener.class); + checkpointDao.getTRCFModel(modelId, listener); + + GetRequest getRequest = requestCaptor.getValue(); + assertEquals(indexName, getRequest.index()); + assertEquals(modelId, getRequest.id()); + // ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Exception.class); + // verify(listener).onFailure(responseCaptor.capture()); + // Exception exception = responseCaptor.getValue(); + // assertTrue(exception instanceof ResourceNotFoundException); + ArgumentCaptor> responseCaptor = ArgumentCaptor.forClass(Optional.class); + verify(listener).onResponse(responseCaptor.capture()); + assertTrue(!responseCaptor.getValue().isPresent()); + } + + @SuppressWarnings("unchecked") + public void test_deleteModelCheckpoint_callListener_whenCompleted() { + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(DeleteRequest.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(clientUtil).asyncRequest(requestCaptor.capture(), any(BiConsumer.class), any(ActionListener.class)); + + ActionListener listener = mock(ActionListener.class); + checkpointDao.deleteModelCheckpoint(modelId, listener); + + DeleteRequest deleteRequest = requestCaptor.getValue(); + assertEquals(indexName, deleteRequest.index()); + assertEquals(modelId, deleteRequest.id()); + + ArgumentCaptor responseCaptor = 
ArgumentCaptor.forClass(Void.class); + verify(listener).onResponse(responseCaptor.capture()); + Void response = responseCaptor.getValue(); + assertEquals(null, response); + } + + @SuppressWarnings("unchecked") + public void test_restore() throws IOException { + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + EntityModel modelToSave = state.getModel(); + + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + Map source = new HashMap<>(); + source.put(CheckpointDao.DETECTOR_ID, state.getId()); + source.put(CheckpointDao.FIELD_MODELV2, checkpointDao.toCheckpoint(modelToSave, modelId).get()); + source.put(CommonName.TIMESTAMP, "2020-10-11T22:58:23.610392Z"); + when(getResponse.getSource()).thenReturn(source); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(getResponse); + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(BiConsumer.class), any(ActionListener.class)); + + ActionListener>> listener = mock(ActionListener.class); + checkpointDao.deserializeModelCheckpoint(modelId, listener); + + ArgumentCaptor>> responseCaptor = ArgumentCaptor.forClass(Optional.class); + verify(listener).onResponse(responseCaptor.capture()); + Optional> response = responseCaptor.getValue(); + assertTrue(response.isPresent()); + Entry entry = response.get(); + OffsetDateTime utcTime = entry.getValue().atOffset(ZoneOffset.UTC); + assertEquals(2020, utcTime.getYear()); + assertEquals(Month.OCTOBER, utcTime.getMonth()); + assertEquals(11, utcTime.getDayOfMonth()); + assertEquals(22, utcTime.getHour()); + assertEquals(58, utcTime.getMinute()); + assertEquals(23, utcTime.getSecond()); + + EntityModel model = entry.getKey(); + Queue queue = model.getSamples(); + Queue samplesToSave = modelToSave.getSamples(); + assertEquals(samplesToSave.size(), queue.size()); + assertTrue(Arrays.equals(samplesToSave.peek(), queue.peek())); + logger.info(modelToSave.getTrcf()); + logger.info(model.getTrcf()); + assertEquals(modelToSave.getTrcf().get().getForest().getTotalUpdates(), model.getTrcf().get().getForest().getTotalUpdates()); + } + + public void test_batch_write_no_index() { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + checkpointDao.batchWrite(new BulkRequest(), null); + verify(indexUtil, times(1)).initCheckpointIndex(any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(new CreateIndexResponse(true, true, ADCommonName.CHECKPOINT_INDEX_NAME)); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + checkpointDao.batchWrite(new BulkRequest(), null); + verify(clientUtil, times(1)).execute(any(), any(), any()); + } + + public void test_batch_write_index_init_no_ack() throws InterruptedException { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(new CreateIndexResponse(false, false, ADCommonName.CHECKPOINT_INDEX_NAME)); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + final CountDownLatch processingLatch = new CountDownLatch(1); + checkpointDao.batchWrite(new BulkRequest(), ActionListener.wrap(response -> assertTrue(false), e -> { + assertTrue(e.getMessage(), e != null); + processingLatch.countDown(); + })); + + processingLatch.await(100, TimeUnit.SECONDS); + } + + public void test_batch_write_index_already_exists() 
{ + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new ResourceAlreadyExistsException("blah")); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + checkpointDao.batchWrite(new BulkRequest(), null); + verify(clientUtil, times(1)).execute(any(), any(), any()); + } + + public void test_batch_write_init_exception() throws InterruptedException { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(false); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException("blah")); + return null; + }).when(indexUtil).initCheckpointIndex(any()); + + final CountDownLatch processingLatch = new CountDownLatch(1); + checkpointDao.batchWrite(new BulkRequest(), ActionListener.wrap(response -> assertTrue(false), e -> { + assertTrue(e.getMessage(), e != null); + processingLatch.countDown(); + })); + + processingLatch.await(100, TimeUnit.SECONDS); + } + + private BulkResponse createBulkResponse(int succeeded, int failed, String[] failedId) { + BulkItemResponse[] bulkItemResponses = new BulkItemResponse[succeeded + failed]; + + ShardId shardId = new ShardId(ADCommonName.CHECKPOINT_INDEX_NAME, "", 1); + int i = 0; + for (; i < failed; i++) { + bulkItemResponses[i] = new BulkItemResponse( + i, + DocWriteRequest.OpType.UPDATE, + new BulkItemResponse.Failure( + ADCommonName.CHECKPOINT_INDEX_NAME, + failedId[i], + new VersionConflictEngineException(shardId, "id", "test") + ) + ); + } + + for (; i < failed + succeeded; i++) { + bulkItemResponses[i] = new BulkItemResponse( + i, + DocWriteRequest.OpType.UPDATE, + new UpdateResponse(shardId, "1", 0L, 1L, 1L, DocWriteResponse.Result.CREATED) + ); + } + + return new BulkResponse(bulkItemResponses, 507); + } + + @SuppressWarnings("unchecked") + public void test_batch_write_no_init() throws InterruptedException { + when(indexUtil.doesCheckpointIndexExist()).thenReturn(true); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(createBulkResponse(2, 0, null)); + return null; + }).when(clientUtil).execute(eq(BulkAction.INSTANCE), any(BulkRequest.class), any(ActionListener.class)); + + final CountDownLatch processingLatch = new CountDownLatch(1); + checkpointDao + .batchWrite(new BulkRequest(), ActionListener.wrap(response -> processingLatch.countDown(), e -> { assertTrue(false); })); + + // we don't expect the waiting time elapsed before the count reached zero + assertTrue(processingLatch.await(100, TimeUnit.SECONDS)); + verify(clientUtil, times(1)).execute(any(), any(), any()); + } + + @SuppressWarnings("unchecked") + public void test_batch_read() throws InterruptedException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + + MultiGetItemResponse[] items = new MultiGetItemResponse[1]; + items[0] = new MultiGetItemResponse( + null, + new MultiGetResponse.Failure( + ADCommonName.CHECKPOINT_INDEX_NAME, + "modelId", + new IndexNotFoundException(ADCommonName.CHECKPOINT_INDEX_NAME) + ) + ); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(clientUtil).execute(eq(MultiGetAction.INSTANCE), any(MultiGetRequest.class), any(ActionListener.class)); + + final CountDownLatch processingLatch = new CountDownLatch(1); + checkpointDao + .batchRead(new MultiGetRequest(), ActionListener.wrap(response -> processingLatch.countDown(), e -> { assertTrue(false); })); + + // we don't 
expect the waiting time elapsed before the count reached zero + assertTrue(processingLatch.await(100, TimeUnit.SECONDS)); + verify(clientUtil, times(1)).execute(any(), any(), any()); + } + + public void test_too_large_checkpoint() throws IOException { + checkpointDao = new CheckpointDao( + client, + clientUtil, + indexName, + gson, + mapper, + converter, + trcfMapper, + trcfSchema, + thresholdingModelClass, + indexUtil, + 1, // make the max checkpoint size 1 byte only + serializeRCFBufferPool, + AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES, + anomalyRate + ); + + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + + assertTrue(checkpointDao.toIndexSource(state).isEmpty()); + } + + public void test_to_index_source() throws IOException { + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + + Map source = checkpointDao.toIndexSource(state); + assertTrue(!source.isEmpty()); + for (Object obj : source.values()) { + // Opensearch cannot recognize Optional + assertTrue(!(obj instanceof Optional)); + } + } + + @SuppressWarnings("unchecked") + public void testBorrowFromPoolFailure() throws Exception { + GenericObjectPool mockSerializeRCFBufferPool = mock(GenericObjectPool.class); + when(mockSerializeRCFBufferPool.borrowObject()).thenThrow(NoSuchElementException.class); + checkpointDao = new CheckpointDao( + client, + clientUtil, + indexName, + gson, + mapper, + converter, + trcfMapper, + trcfSchema, + thresholdingModelClass, + indexUtil, + 1, // make the max checkpoint size 1 byte only + mockSerializeRCFBufferPool, + AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES, + anomalyRate + ); + + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + assertTrue(!checkpointDao.toCheckpoint(state.getModel(), modelId).get().isEmpty()); + } + + public void testMapperFailure() throws IOException { + ThresholdedRandomCutForestMapper mockMapper = mock(ThresholdedRandomCutForestMapper.class); + when(mockMapper.toState(any())).thenThrow(RuntimeException.class); + + checkpointDao = new CheckpointDao( + client, + clientUtil, + indexName, + gson, + mapper, + converter, + mockMapper, + trcfSchema, + thresholdingModelClass, + indexUtil, + 1, // make the max checkpoint size 1 byte only + serializeRCFBufferPool, + AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES, + anomalyRate + ); + + // make sure sample size is not 0 otherwise sample size won't be written to checkpoint + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(1).build()); + String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + assertEquals(null, JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertTrue(null != JsonDeserializer.getChildNode(json, CommonName.ENTITY_SAMPLE)); + // assertTrue(null != JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_THRESHOLD)); + // assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + } + + public void testEmptySample() throws IOException { + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + // assertTrue(null != JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertEquals(null, JsonDeserializer.getChildNode(json, CommonName.ENTITY_SAMPLE)); + // 
assertTrue(null != JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_THRESHOLD)); + assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + } + + public void testToCheckpointErcfCheckoutFail() throws Exception { + when(serializeRCFBufferPool.borrowObject()).thenThrow(RuntimeException.class); + + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + + assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + } + + @SuppressWarnings("unchecked") + private void setUpMockTrcf() { + trcfMapper = mock(ThresholdedRandomCutForestMapper.class); + trcfSchema = mock(Schema.class); + checkpointDao = new CheckpointDao( + client, + clientUtil, + indexName, + gson, + mapper, + converter, + trcfMapper, + trcfSchema, + thresholdingModelClass, + indexUtil, + maxCheckpointBytes, + serializeRCFBufferPool, + AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES, + anomalyRate + ); + } + + public void testToCheckpointTrcfCheckoutBufferFail() throws Exception { + setUpMockTrcf(); + when(trcfMapper.toState(any())).thenThrow(RuntimeException.class).thenReturn(null); + + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + + assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + } + + public void testToCheckpointTrcfFailNewBuffer() throws Exception { + setUpMockTrcf(); + doReturn(null).when(serializeRCFBufferPool).borrowObject(); + when(trcfMapper.toState(any())).thenThrow(RuntimeException.class); + + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + + assertNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + } + + public void testToCheckpointTrcfCheckoutBufferInvalidateFail() throws Exception { + setUpMockTrcf(); + when(trcfMapper.toState(any())).thenThrow(RuntimeException.class).thenReturn(null); + doThrow(RuntimeException.class).when(serializeRCFBufferPool).invalidateObject(any()); + + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + + assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + } + + public void testFromEntityModelCheckpointWithTrcf() throws Exception { + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + String model = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + + Map entity = new HashMap<>(); + entity.put(FIELD_MODELV2, model); + entity.put(CommonName.TIMESTAMP, Instant.now().toString()); + Optional> result = checkpointDao.fromEntityModelCheckpoint(entity, this.modelId); + + assertTrue(result.isPresent()); + Entry pair = result.get(); + EntityModel entityModel = pair.getKey(); + assertTrue(entityModel.getTrcf().isPresent()); + } + + public void testFromEntityModelCheckpointTrcfMapperFail() throws Exception { + setUpMockTrcf(); + when(trcfMapper.toModel(any())).thenThrow(RuntimeException.class); + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + String model = 
checkpointDao.toCheckpoint(state.getModel(), modelId).get(); + + Map entity = new HashMap<>(); + entity.put(FIELD_MODELV2, model); + entity.put(CommonName.TIMESTAMP, Instant.now().toString()); + Optional> result = checkpointDao.fromEntityModelCheckpoint(entity, this.modelId); + + assertTrue(result.isPresent()); + Entry pair = result.get(); + EntityModel entityModel = pair.getKey(); + assertFalse(entityModel.getTrcf().isPresent()); + } + + private Pair, Instant> setUp1_0Model(String checkpointFileName) throws FileNotFoundException, + IOException, + URISyntaxException { + String model = null; + try ( + FileReader v1CheckpointFile = new FileReader( + new File(getClass().getResource(checkpointFileName).toURI()), + Charset.defaultCharset() + ); + BufferedReader rr = new BufferedReader(v1CheckpointFile) + ) { + model = rr.readLine(); + } + + Instant now = Instant.now(); + Map entity = new HashMap<>(); + entity.put(CommonName.FIELD_MODEL, model); + entity.put(CommonName.TIMESTAMP, now.toString()); + return Pair.of(entity, now); + } + + public void testFromEntityModelCheckpointBWC() throws FileNotFoundException, IOException, URISyntaxException { + Pair, Instant> modelPair = setUp1_0Model("checkpoint_2.json"); + Instant now = modelPair.getRight(); + + Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); + assertTrue(result.isPresent()); + Entry pair = result.get(); + assertEquals(now, pair.getValue()); + + EntityModel entityModel = pair.getKey(); + + Queue samples = entityModel.getSamples(); + assertEquals(6, samples.size()); + double[] firstSample = samples.peek(); + assertEquals(1, firstSample.length); + assertEquals(0.6832234717598454, firstSample[0], 1e-10); + + ThresholdedRandomCutForest trcf = entityModel.getTrcf().get(); + RandomCutForest forest = trcf.getForest(); + assertEquals(1, forest.getDimensions()); + assertEquals(10, forest.getNumberOfTrees()); + assertEquals(256, forest.getSampleSize()); + // there are at least 10 scores in the checkpoint + assertTrue(trcf.getThresholder().getCount() > 10); + + Random random = new Random(0); + for (int i = 0; i < 100; i++) { + double[] point = getPoint(forest.getDimensions(), random); + double score = trcf.process(point, 0).getRCFScore(); + assertTrue(score > 0); + forest.update(point); + } + } + + public void testFromEntityModelCheckpointModelTooLarge() throws FileNotFoundException, IOException, URISyntaxException { + Pair, Instant> modelPair = setUp1_0Model("checkpoint_2.json"); + checkpointDao = new CheckpointDao( + client, + clientUtil, + indexName, + gson, + mapper, + converter, + trcfMapper, + trcfSchema, + thresholdingModelClass, + indexUtil, + 100_000, // checkpoint_2.json is of 224603 bytes. + serializeRCFBufferPool, + AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES, + anomalyRate + ); + Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); + // checkpoint is only configured to take in 1 MB checkpoint at most. But the checkpoint here is of 1408047 bytes. 
+ assertTrue(!result.isPresent()); + } + + // test no model is present in checkpoint + public void testFromEntityModelCheckpointEmptyModel() throws FileNotFoundException, IOException, URISyntaxException { + Map entity = new HashMap<>(); + entity.put(CommonName.TIMESTAMP, Instant.now().toString()); + + Optional> result = checkpointDao.fromEntityModelCheckpoint(entity, this.modelId); + assertTrue(!result.isPresent()); + } + + public void testFromEntityModelCheckpointEmptySamples() throws FileNotFoundException, IOException, URISyntaxException { + Pair, Instant> modelPair = setUp1_0Model("checkpoint_1.json"); + Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); + assertTrue(result.isPresent()); + Queue samples = result.get().getKey().getSamples(); + assertEquals(0, samples.size()); + } + + public void testFromEntityModelCheckpointNoRCF() throws FileNotFoundException, IOException, URISyntaxException { + Pair, Instant> modelPair = setUp1_0Model("checkpoint_3.json"); + Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); + assertTrue(result.isPresent()); + assertTrue(!result.get().getKey().getTrcf().isPresent()); + } + + public void testFromEntityModelCheckpointNoThreshold() throws FileNotFoundException, IOException, URISyntaxException { + Pair, Instant> modelPair = setUp1_0Model("checkpoint_4.json"); + Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); + assertTrue(result.isPresent()); + + ThresholdedRandomCutForest trcf = result.get().getKey().getTrcf().get(); + RandomCutForest forest = trcf.getForest(); + assertEquals(1, forest.getDimensions()); + assertEquals(10, forest.getNumberOfTrees()); + assertEquals(256, forest.getSampleSize()); + // there are no scores in the checkpoint + assertEquals(0, trcf.getThresholder().getCount()); + } + + public void testFromEntityModelCheckpointWithEntity() throws Exception { + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).entityAttributes(true).build()); + Map content = checkpointDao.toIndexSource(state); + // Opensearch will convert from java.time.ZonedDateTime to String. 
Here I am converting to simulate that + content.put(CommonName.TIMESTAMP, "2021-09-23T05:00:37.93195Z"); + + Optional> result = checkpointDao.fromEntityModelCheckpoint(content, this.modelId); + + assertTrue(result.isPresent()); + Entry pair = result.get(); + EntityModel entityModel = pair.getKey(); + assertTrue(entityModel.getEntity().isPresent()); + assertEquals(state.getModel().getEntity().get(), entityModel.getEntity().get()); + } + + private double[] getPoint(int dimensions, Random random) { + double[] point = new double[dimensions]; + for (int i = 0; i < point.length; i++) { + point[i] = random.nextDouble(); + } + return point; + } + + // The checkpoint used for this test is from a single-stream detector + public void testDeserializeRCFModelPreINIT() throws Exception { + // Model in file 1_3_0_rcf_model_pre_init.json has not passed initialization yet + URI uri = ClassLoader.getSystemResource("org/opensearch/ad/ml/1_3_0_rcf_model_pre_init.json").toURI(); + String filePath = Paths.get(uri).toString(); + String json = Files.readString(Paths.get(filePath), Charset.defaultCharset()); + Map map = gson.fromJson(json, Map.class); + String model = (String) ((Map) ((Map) ((ArrayList) ((Map) map.get("hits")).get("hits")).get(0)).get("_source")).get("modelV2"); + ThresholdedRandomCutForest forest = checkpointDao.toTrcf(model); + assertEquals(256, forest.getForest().getSampleSize()); + assertEquals(8, forest.getForest().getShingleSize()); + assertEquals(30, forest.getForest().getNumberOfTrees()); + } + + // The checkpoint used for this test is from a single-stream detector + public void testDeserializeRCFModelPostINIT() throws Exception { + // Model in file rc1_model_single_running is from RCF-3.0-rc1 + URI uri = ClassLoader.getSystemResource("org/opensearch/ad/ml/rc1_model_single_running.json").toURI(); + String filePath = Paths.get(uri).toString(); + String json = Files.readString(Paths.get(filePath), Charset.defaultCharset()); + Map map = gson.fromJson(json, Map.class); + String model = (String) ((Map) ((Map) ((ArrayList) ((Map) map.get("hits")).get("hits")).get(0)).get("_source")).get("modelV2"); + ThresholdedRandomCutForest forest = checkpointDao.toTrcf(model); + assertEquals(256, forest.getForest().getSampleSize()); + assertEquals(8, forest.getForest().getShingleSize()); + assertEquals(30, forest.getForest().getNumberOfTrees()); + } + + // This test is intended to check whether, given a checkpoint created by RCF-3.0-rc1 ("rc1_trcf_model_direct.json") + // and the same sample data, RCF-3.0-rc1 and the current RCF version (this test was originally created when 3.0-rc2.1 was in use) + // produce the same anomaly scores and grades. + // The scores and grades in this method were produced by AD running with the RCF3.0-rc1 dependency, + // and this test runs with the most recent RCF dependency that is being pulled in by this project. + public void testDeserializeTRCFModel() throws Exception { + // Model in file rc1_trcf_model_direct is a checkpoint created by RCF-3.0-rc1 + URI uri = ClassLoader.getSystemResource("org/opensearch/ad/ml/rc1_trcf_model_direct.json").toURI(); + String filePath = Paths.get(uri).toString(); + String json = Files.readString(Paths.get(filePath), Charset.defaultCharset()); + // For the parsing of .toTrcf to work I had to manually change "\u003d" in code back to =. 
+ // In the byte array this doesn't seem to be an issue, but whenever the byte array response is read into a file it + converts "=" to "\u003d" https://groups.google.com/g/google-gson/c/JDHUo9DWyyM?pli=1 + // I also needed to bypass the trcf as it wasn't being read as a key value but instead as part of the string + Map map = gson.fromJson(json, Map.class); + String model = (String) ((Map) ((Map) ((ArrayList) ((Map) map.get("hits")).get("hits")).get(0)).get("_source")).get("modelV2"); + model = model.split(":")[1].substring(1); + ThresholdedRandomCutForest forest = checkpointDao.toTrcf(model); + + List coldStartData = new ArrayList<>(); + double[] sample1 = new double[] { 57.0 }; + double[] sample2 = new double[] { 1.0 }; + double[] sample3 = new double[] { -19.0 }; + double[] sample4 = new double[] { 13.0 }; + double[] sample5 = new double[] { 41.0 }; + + coldStartData.add(sample1); + coldStartData.add(sample2); + coldStartData.add(sample3); + coldStartData.add(sample4); + coldStartData.add(sample5); + + // These scores were generated with the sample data on RCF3.0-rc1 and we are comparing them + // to the scores generated by the imported RCF3.0-rc2.1 + List scores = new ArrayList<>(); + scores.add(4.814651669367903); + scores.add(5.566968073093689); + scores.add(5.919907610660049); + scores.add(5.770278090352401); + scores.add(5.319779117320102); + + List grade = new ArrayList<>(); + grade.add(1.0); + grade.add(0.0); + grade.add(0.0); + grade.add(0.0); + grade.add(0.0); + for (int i = 0; i < coldStartData.size(); i++) { + forest.process(coldStartData.get(i), 0); + AnomalyDescriptor descriptor = forest.process(coldStartData.get(i), 0); + assertEquals(descriptor.getRCFScore(), scores.get(i), 1e-9); + assertEquals(descriptor.getAnomalyGrade(), grade.get(i), 1e-9); + } + } + + public void testShouldSave() { + assertTrue(!checkpointDao.shouldSave(Instant.MIN, false, null, clock)); + assertTrue(checkpointDao.shouldSave(Instant.ofEpochMilli(Instant.now().toEpochMilli()), true, Duration.ofHours(6), clock)); + // now + 6 hrs > Instant.now + assertTrue(!checkpointDao.shouldSave(Instant.ofEpochMilli(Instant.now().toEpochMilli()), false, Duration.ofHours(6), clock)); + // 1658863778000L + 6 hrs < Instant.now + assertTrue(checkpointDao.shouldSave(Instant.ofEpochMilli(1658863778000L), false, Duration.ofHours(6), clock)); + } +} diff --git a/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java-e b/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java-e new file mode 100644 index 000000000..c94c145cb --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java-e @@ -0,0 +1,183 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + + package org.opensearch.ad.ml; + + import static org.mockito.ArgumentMatchers.any; + import static org.mockito.ArgumentMatchers.eq; + import static org.mockito.Mockito.doAnswer; + import static org.mockito.Mockito.mock; + import static org.mockito.Mockito.when; + + import java.util.Arrays; + import java.util.Collections; + import java.util.Locale; + + import org.apache.commons.pool2.impl.GenericObjectPool; + import org.junit.After; + import org.junit.Before; + import org.mockito.Mock; + import org.opensearch.OpenSearchException; + import org.opensearch.action.ActionListener; + import org.opensearch.ad.constant.ADCommonName; + import org.opensearch.ad.indices.ADIndexManagement; + import org.opensearch.ad.util.ClientUtil; + import org.opensearch.client.Client; + import org.opensearch.index.IndexNotFoundException; + import org.opensearch.index.reindex.BulkByScrollResponse; + import org.opensearch.index.reindex.DeleteByQueryAction; + import org.opensearch.index.reindex.ScrollableHitSource; + import org.opensearch.timeseries.AbstractTimeSeriesTest; + + import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper; + import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState; + import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV3StateConverter; + import com.amazon.randomcutforest.state.RandomCutForestMapper; + import com.google.gson.Gson; + + import io.protostuff.LinkedBuffer; + import io.protostuff.Schema; + + /** + * CheckpointDaoTests cannot extend the basic ES test case, and logs written + * during a test run cannot be checked using the functions in ADAbstractTest. Create a new + * class for tests that require checking logs. + * + */ + public class CheckpointDeleteTests extends AbstractTimeSeriesTest { + private enum DeleteExecutionMode { + NORMAL, + INDEX_NOT_FOUND, + FAILURE, + PARTIAL_FAILURE + } + + private CheckpointDao checkpointDao; + private Client client; + private ClientUtil clientUtil; + private Gson gson; + private ADIndexManagement indexUtil; + private String detectorId; + private int maxCheckpointBytes; + private GenericObjectPool objectPool; + + @Mock + private ThresholdedRandomCutForestMapper ercfMapper; + + @Mock + private Schema ercfSchema; + + double anomalyRate; + + @SuppressWarnings("unchecked") + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + super.setUpLog4jForJUnit(CheckpointDao.class); + + client = mock(Client.class); + clientUtil = mock(ClientUtil.class); + gson = null; + indexUtil = mock(ADIndexManagement.class); + detectorId = "123"; + maxCheckpointBytes = 1_000_000; + + RandomCutForestMapper mapper = mock(RandomCutForestMapper.class); + V1JsonToV3StateConverter converter = mock(V1JsonToV3StateConverter.class); + + objectPool = mock(GenericObjectPool.class); + int deserializeRCFBufferSize = 512; + anomalyRate = 0.005; + checkpointDao = new CheckpointDao( + client, + clientUtil, + ADCommonName.CHECKPOINT_INDEX_NAME, + gson, + mapper, + converter, + ercfMapper, + ercfSchema, + HybridThresholdingModel.class, + indexUtil, + maxCheckpointBytes, + objectPool, + deserializeRCFBufferSize, + anomalyRate + ); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + super.tearDownLog4jForJUnit(); + } + + @SuppressWarnings("unchecked") + public void delete_by_detector_id_template(DeleteExecutionMode mode) { + long deletedDocNum = 10L; + BulkByScrollResponse deleteByQueryResponse = mock(BulkByScrollResponse.class); + when(deleteByQueryResponse.getDeleted()).thenReturn(deletedDocNum); 
+ doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + args.length >= 3 + ); + assertTrue(args[2] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[2]; + + assertTrue(listener != null); + if (mode == DeleteExecutionMode.INDEX_NOT_FOUND) { + listener.onFailure(new IndexNotFoundException(ADCommonName.CHECKPOINT_INDEX_NAME)); + } else if (mode == DeleteExecutionMode.FAILURE) { + listener.onFailure(new OpenSearchException("")); + } else { + if (mode == DeleteExecutionMode.PARTIAL_FAILURE) { + when(deleteByQueryResponse.getSearchFailures()) + .thenReturn( + Collections + .singletonList(new ScrollableHitSource.SearchFailure(new OpenSearchException("foo"), "bar", 1, "blah")) + ); + } + listener.onResponse(deleteByQueryResponse); + } + + return null; + }).when(client).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); + + checkpointDao.deleteModelCheckpointByDetectorId(detectorId); + } + + public void testDeleteSingleNormal() throws Exception { + delete_by_detector_id_template(DeleteExecutionMode.NORMAL); + assertTrue(testAppender.containsMessage(CheckpointDao.DOC_GOT_DELETED_LOG_MSG)); + } + + public void testDeleteSingleIndexNotFound() throws Exception { + delete_by_detector_id_template(DeleteExecutionMode.INDEX_NOT_FOUND); + assertTrue(testAppender.containsMessage(CheckpointDao.INDEX_DELETED_LOG_MSG)); + } + + public void testDeleteSingleResultFailure() throws Exception { + delete_by_detector_id_template(DeleteExecutionMode.FAILURE); + assertTrue(testAppender.containsMessage(CheckpointDao.NOT_ABLE_TO_DELETE_LOG_MSG)); + } + + public void testDeleteSingleResultPartialFailure() throws Exception { + delete_by_detector_id_template(DeleteExecutionMode.PARTIAL_FAILURE); + assertTrue(testAppender.containsMessage(CheckpointDao.SEARCH_FAILURE_LOG_MSG)); + assertTrue(testAppender.containsMessage(CheckpointDao.DOC_GOT_DELETED_LOG_MSG)); + } +} diff --git a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java-e b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java-e new file mode 100644 index 000000000..34265b0e6 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java-e @@ -0,0 +1,829 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ml; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Clock; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Queue; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.ml.ModelManager.ModelType; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import test.org.opensearch.ad.util.LabelledAnomalyGenerator; +import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.MultiDimDataWithTime; + +import com.amazon.randomcutforest.config.Precision; +import com.amazon.randomcutforest.parkservices.AnomalyDescriptor; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; +import com.google.common.collect.ImmutableList; + +public class EntityColdStarterTests extends AbstractCosineDataTest { + + @BeforeClass + public static void initOnce() { + ClusterService clusterService = mock(ClusterService.class); + + Set> settingSet = ADEnabledSetting.settings.values().stream().collect(Collectors.toSet()); + + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, settingSet)); + + ADEnabledSetting.getInstance().init(clusterService); + } + + @AfterClass + public static void clearOnce() { + // restore to default value + ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, false); + } + + @Override + public void setUp() throws Exception { + super.setUp(); + ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, Boolean.TRUE); + } + + @Override + public void tearDown() throws Exception { + ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, Boolean.FALSE); + super.tearDown(); + } + + // train using samples directly + public void 
testTrainUsingSamples() throws InterruptedException { + Queue samples = MLUtil.createQueueSamples(numMinSamples); + EntityModel model = new EntityModel(entity, samples, null); + modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + entityColdStarter.trainModel(entity, detectorId, modelState, listener); + assertTrue(model.getTrcf().isPresent()); + ThresholdedRandomCutForest ercf = model.getTrcf().get(); + assertEquals(numMinSamples, ercf.getForest().getTotalUpdates()); + + checkSemaphoreRelease(); + } + + public void testColdStart() throws InterruptedException, IOException { + Queue samples = MLUtil.createQueueSamples(1); + EntityModel model = new EntityModel(entity, samples, null); + modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(Optional.of(1602269260000L)); + return null; + }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any()); + + List> coldStartSamples = new ArrayList<>(); + + double[] sample1 = new double[] { 57.0 }; + double[] sample2 = new double[] { 1.0 }; + double[] sample3 = new double[] { -19.0 }; + + coldStartSamples.add(Optional.of(sample1)); + coldStartSamples.add(Optional.of(sample2)); + coldStartSamples.add(Optional.of(sample3)); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(4); + listener.onResponse(coldStartSamples); + return null; + }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); + + entityColdStarter.trainModel(entity, detectorId, modelState, listener); + checkSemaphoreRelease(); + + assertTrue(model.getTrcf().isPresent()); + ThresholdedRandomCutForest ercf = model.getTrcf().get(); + // 1 round: stride * (samples - 1) + 1 = 60 * 2 + 1 = 121 + // plus 1 existing sample + assertEquals(121, ercf.getForest().getTotalUpdates()); + assertTrue("size: " + model.getSamples().size(), model.getSamples().isEmpty()); + + checkSemaphoreRelease(); + + released.set(false); + // too frequent cold start of the same detector will fail + samples = MLUtil.createQueueSamples(1); + model = new EntityModel(entity, samples, null); + entityColdStarter.trainModel(entity, detectorId, modelState, listener); + + assertFalse(model.getTrcf().isPresent()); + // the samples is not touched since cold start does not happen + assertEquals("size: " + model.getSamples().size(), 1, model.getSamples().size()); + checkSemaphoreRelease(); + + List expectedColdStartData = new ArrayList<>(); + + // for function interpolate: + // 1st parameter is a matrix of size numFeatures * numSamples + // 2nd parameter is the number of interpolants including two samples + double[][] interval1 = imputer.impute(new double[][] { new double[] { sample1[0], sample2[0] } }, 61); + expectedColdStartData.addAll(convertToFeatures(interval1, 60)); + double[][] interval2 = imputer.impute(new double[][] { new double[] { sample2[0], sample3[0] } }, 61); + expectedColdStartData.addAll(convertToFeatures(interval2, 61)); + assertEquals(121, expectedColdStartData.size()); + + diffTesting(modelState, expectedColdStartData); + } + + // min max: miss one + public void testMissMin() throws IOException, InterruptedException { + Queue samples = MLUtil.createQueueSamples(1); + EntityModel model = new EntityModel(entity, samples, null); + modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + 
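+ // Earliest data time for this entity cannot be determined: getEntityMinDataTime returns Optional.empty(),
+ // so cold start should give up before requesting any cold start sample ranges.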
+ doAnswer(invocation -> { + ActionListener<Optional<Long>> listener = invocation.getArgument(2); + listener.onResponse(Optional.empty()); + return null; + }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any()); + + entityColdStarter.trainModel(entity, detectorId, modelState, listener); + + verify(searchFeatureDao, never()).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); + + assertTrue(!model.getTrcf().isPresent()); + checkSemaphoreRelease(); + } + + /** + * Perform differential testing using a TRCF model with input cold start data and the modelState + * @param modelState an initialized model state + * @param coldStartData cold start data that initialized the modelState + */ + private void diffTesting(ModelState<EntityModel> modelState, List<double[]> coldStartData) { + int inputDimension = detector.getEnabledFeatureIds().size(); + + ThresholdedRandomCutForest refTRcf = ThresholdedRandomCutForest + .builder() + .compact(true) + .dimensions(inputDimension * detector.getShingleSize()) + .precision(Precision.FLOAT_32) + .randomSeed(rcfSeed) + .numberOfTrees(AnomalyDetectorSettings.NUM_TREES) + .shingleSize(detector.getShingleSize()) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + .timeDecay(AnomalyDetectorSettings.TIME_DECAY) + .outputAfter(numMinSamples) + .initialAcceptFraction(0.125d) + .parallelExecutionEnabled(false) + .sampleSize(AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE) + .internalShinglingEnabled(true) + .anomalyRate(1 - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE) + .build(); + + for (int i = 0; i < coldStartData.size(); i++) { + refTRcf.process(coldStartData.get(i), 0); + } + assertEquals( + "Expect " + coldStartData.size() + " but got " + refTRcf.getForest().getTotalUpdates(), + coldStartData.size(), + refTRcf.getForest().getTotalUpdates() + ); + + Random r = new Random(); + + // make sure we trained the expected models + for (int i = 0; i < 100; i++) { + double[] point = r.ints(inputDimension, 0, 50).asDoubleStream().toArray(); + AnomalyDescriptor descriptor = refTRcf.process(point, 0); + ThresholdingResult result = modelManager + .getAnomalyResultForEntity(point, modelState, modelId, entity, detector.getShingleSize()); + assertEquals(descriptor.getRCFScore(), result.getRcfScore(), 1e-10); + assertEquals(descriptor.getAnomalyGrade(), result.getGrade(), 1e-10); + } + } + + /** + * Convert a double array of size numFeatures * numSamples to a double array of + * size numSamples * numFeatures + * @param interval input array + * @param numValsToKeep number of samples to keep in the input array.
Used to + * keep the last sample in the input array out in case of repeated inclusion + * @return converted value + */ + private List convertToFeatures(double[][] interval, int numValsToKeep) { + List ret = new ArrayList<>(); + for (int j = 0; j < numValsToKeep; j++) { + ret.add(new double[] { interval[0][j] }); + } + return ret; + } + + // two segments of samples, one segment has 3 samples, while another one has only 1 + public void testTwoSegmentsWithSingleSample() throws InterruptedException, IOException { + Queue samples = MLUtil.createQueueSamples(1); + double[] savedSample = samples.peek(); + EntityModel model = new EntityModel(entity, samples, null); + modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(Optional.of(1602269260000L)); + return null; + }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any()); + + List> coldStartSamples = new ArrayList<>(); + double[] sample1 = new double[] { 57.0 }; + double[] sample2 = new double[] { 1.0 }; + double[] sample3 = new double[] { -19.0 }; + double[] sample5 = new double[] { -17.0 }; + coldStartSamples.add(Optional.of(sample1)); + coldStartSamples.add(Optional.of(sample2)); + coldStartSamples.add(Optional.of(sample3)); + coldStartSamples.add(Optional.empty()); + coldStartSamples.add(Optional.of(sample5)); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(4); + listener.onResponse(coldStartSamples); + return null; + }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); + + entityColdStarter.trainModel(entity, detectorId, modelState, listener); + checkSemaphoreRelease(); + assertTrue(model.getTrcf().isPresent()); + + // 1 round: stride * (samples - 1) + 1 = 60 * 4 + 1 = 241 + // if 241 < shingle size + numMinSamples, then another round is performed + assertEquals(241, modelState.getModel().getTrcf().get().getForest().getTotalUpdates()); + checkSemaphoreRelease(); + + List expectedColdStartData = new ArrayList<>(); + + // for function interpolate: + // 1st parameter is a matrix of size numFeatures * numSamples + // 2nd parameter is the number of interpolants including two samples + double[][] interval1 = imputer.impute(new double[][] { new double[] { sample1[0], sample2[0] } }, 61); + expectedColdStartData.addAll(convertToFeatures(interval1, 60)); + double[][] interval2 = imputer.impute(new double[][] { new double[] { sample2[0], sample3[0] } }, 61); + expectedColdStartData.addAll(convertToFeatures(interval2, 60)); + double[][] interval3 = imputer.impute(new double[][] { new double[] { sample3[0], sample5[0] } }, 121); + expectedColdStartData.addAll(convertToFeatures(interval3, 121)); + assertTrue("size: " + model.getSamples().size(), model.getSamples().isEmpty()); + assertEquals(241, expectedColdStartData.size()); + diffTesting(modelState, expectedColdStartData); + } + + // two segments of samples, one segment has 3 samples, while another one 2 samples + public void testTwoSegments() throws InterruptedException, IOException { + Queue samples = MLUtil.createQueueSamples(1); + double[] savedSample = samples.peek(); + EntityModel model = new EntityModel(entity, samples, null); + modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + 
listener.onResponse(Optional.of(1602269260000L)); + return null; + }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any()); + + List> coldStartSamples = new ArrayList<>(); + double[] sample1 = new double[] { 57.0 }; + double[] sample2 = new double[] { 1.0 }; + double[] sample3 = new double[] { -19.0 }; + double[] sample5 = new double[] { -17.0 }; + double[] sample6 = new double[] { -38.0 }; + coldStartSamples.add(Optional.of(new double[] { 57.0 })); + coldStartSamples.add(Optional.of(new double[] { 1.0 })); + coldStartSamples.add(Optional.of(new double[] { -19.0 })); + coldStartSamples.add(Optional.empty()); + coldStartSamples.add(Optional.of(new double[] { -17.0 })); + coldStartSamples.add(Optional.of(new double[] { -38.0 })); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(4); + listener.onResponse(coldStartSamples); + return null; + }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); + + entityColdStarter.trainModel(entity, detectorId, modelState, listener); + checkSemaphoreRelease(); + + assertTrue(model.getTrcf().isPresent()); + ThresholdedRandomCutForest ercf = model.getTrcf().get(); + // 1 rounds: stride * (samples - 1) + 1 = 60 * 5 + 1 = 301 + assertEquals(301, ercf.getForest().getTotalUpdates()); + checkSemaphoreRelease(); + + List expectedColdStartData = new ArrayList<>(); + + // for function interpolate: + // 1st parameter is a matrix of size numFeatures * numSamples + // 2nd parameter is the number of interpolants including two samples + double[][] interval1 = imputer.impute(new double[][] { new double[] { sample1[0], sample2[0] } }, 61); + expectedColdStartData.addAll(convertToFeatures(interval1, 60)); + double[][] interval2 = imputer.impute(new double[][] { new double[] { sample2[0], sample3[0] } }, 61); + expectedColdStartData.addAll(convertToFeatures(interval2, 60)); + double[][] interval3 = imputer.impute(new double[][] { new double[] { sample3[0], sample5[0] } }, 121); + expectedColdStartData.addAll(convertToFeatures(interval3, 120)); + double[][] interval4 = imputer.impute(new double[][] { new double[] { sample5[0], sample6[0] } }, 61); + expectedColdStartData.addAll(convertToFeatures(interval4, 61)); + assertEquals(301, expectedColdStartData.size()); + assertTrue("size: " + model.getSamples().size(), model.getSamples().isEmpty()); + diffTesting(modelState, expectedColdStartData); + } + + public void testThrottledColdStart() throws InterruptedException { + Queue samples = MLUtil.createQueueSamples(1); + EntityModel model = new EntityModel(entity, samples, null); + modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onFailure(new OpenSearchRejectedExecutionException("")); + return null; + }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any()); + + entityColdStarter.trainModel(entity, detectorId, modelState, listener); + + entityColdStarter.trainModel(entity, "456", modelState, listener); + + // only the first one makes the call + verify(searchFeatureDao, times(1)).getEntityMinDataTime(any(), any(), any()); + checkSemaphoreRelease(); + } + + public void testColdStartException() throws InterruptedException { + Queue samples = MLUtil.createQueueSamples(1); + EntityModel model = new EntityModel(entity, samples, null); + modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + + 
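// Make the min data time lookup fail with a TimeSeriesException; the cold starter should record the + // error so stateManager.getLastDetectionError(detectorId) is non-null below. + 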
doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onFailure(new TimeSeriesException(detectorId, "")); + return null; + }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any()); + + entityColdStarter.trainModel(entity, detectorId, modelState, listener); + + assertTrue(stateManager.getLastDetectionError(detectorId) != null); + checkSemaphoreRelease(); + } + + @SuppressWarnings("unchecked") + public void testNotEnoughSamples() throws InterruptedException, IOException { + Queue samples = MLUtil.createQueueSamples(1); + EntityModel model = new EntityModel(entity, samples, null); + modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + + detector = TestHelpers.AnomalyDetectorBuilder + .newInstance() + .setDetectionInterval(new IntervalTimeConfiguration(13, ChronoUnit.MINUTES)) + .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) + .build(); + doAnswer(invocation -> { + GetRequest request = invocation.getArgument(0); + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, CommonName.CONFIG_INDEX)); + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(Optional.of(1602269260000L)); + return null; + }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any()); + + List> coldStartSamples = new ArrayList<>(); + coldStartSamples.add(Optional.of(new double[] { 57.0 })); + coldStartSamples.add(Optional.of(new double[] { 1.0 })); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(4); + listener.onResponse(coldStartSamples); + return null; + }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); + + entityColdStarter.trainModel(entity, detectorId, modelState, listener); + checkSemaphoreRelease(); + + assertTrue(!model.getTrcf().isPresent()); + // 1st round we add 57 and 1. + // 2nd round we add 57 and 1. 
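+ // Neither round yields enough points to train, so the model stays untrained and both rounds' raw
+ // samples remain queued: 2 rounds * 2 samples = 4 entries (57, 1, 57, 1).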
+ Queue currentSamples = model.getSamples(); + assertEquals("real sample size is " + currentSamples.size(), 4, currentSamples.size()); + int j = 0; + while (!currentSamples.isEmpty()) { + double[] element = currentSamples.poll(); + assertEquals(1, element.length); + if (j == 0 || j == 2) { + assertEquals(57, element[0], 1e-10); + } else { + assertEquals(1, element[0], 1e-10); + } + j++; + } + } + + @SuppressWarnings("unchecked") + public void testEmptyDataRange() throws InterruptedException { + Queue samples = MLUtil.createQueueSamples(1); + EntityModel model = new EntityModel(entity, samples, null); + modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + + // the min-max range 894056973000L~894057860000L is too small and thus no data range can be found + when(clock.millis()).thenReturn(894057860000L); + + doAnswer(invocation -> { + GetRequest request = invocation.getArgument(0); + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(TestHelpers.createGetResponse(detector, detector.getId(), CommonName.CONFIG_INDEX)); + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(Optional.of(894056973000L)); + return null; + }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any()); + + entityColdStarter.trainModel(entity, detectorId, modelState, listener); + checkSemaphoreRelease(); + + assertTrue(!model.getTrcf().isPresent()); + // the min-max range is too small and thus no data range can be found + assertEquals("real sample size is " + model.getSamples().size(), 1, model.getSamples().size()); + } + + public void testTrainModelFromExistingSamplesEnoughSamples() { + int inputDimension = 2; + int dimensions = inputDimension * detector.getShingleSize(); + + ThresholdedRandomCutForest.Builder rcfConfig = ThresholdedRandomCutForest + .builder() + .compact(true) + .dimensions(dimensions) + .precision(Precision.FLOAT_32) + .randomSeed(rcfSeed) + .numberOfTrees(AnomalyDetectorSettings.NUM_TREES) + .shingleSize(detector.getShingleSize()) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + .timeDecay(AnomalyDetectorSettings.TIME_DECAY) + .outputAfter(numMinSamples) + .initialAcceptFraction(0.125d) + .parallelExecutionEnabled(false) + .sampleSize(AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE) + .internalShinglingEnabled(true) + .anomalyRate(1 - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE); + Tuple, ThresholdedRandomCutForest> models = MLUtil.prepareModel(inputDimension, rcfConfig); + Queue samples = models.v1(); + ThresholdedRandomCutForest rcf = models.v2(); + + EntityModel model = new EntityModel(entity, samples, null); + modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + + Random r = new Random(); + + // make sure we trained the expected models + for (int i = 0; i < 100; i++) { + double[] point = r.ints(inputDimension, 0, 50).asDoubleStream().toArray(); + AnomalyDescriptor descriptor = rcf.process(point, 0); + ThresholdingResult result = modelManager + .getAnomalyResultForEntity(point, modelState, modelId, entity, detector.getShingleSize()); + assertEquals(descriptor.getRCFScore(), result.getRcfScore(), 1e-10); + assertEquals(descriptor.getAnomalyGrade(), result.getGrade(), 1e-10); + } + } + + public void testTrainModelFromExistingSamplesNotEnoughSamples() { + Queue samples = 
new ArrayDeque<>(); + EntityModel model = new EntityModel(entity, samples, null); + modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + entityColdStarter.trainModelFromExistingSamples(modelState, detector.getShingleSize()); + assertTrue(!modelState.getModel().getTrcf().isPresent()); + } + + @SuppressWarnings("unchecked") + private void accuracyTemplate(int detectorIntervalMins, float precisionThreshold, float recallThreshold) throws Exception { + int baseDimension = 2; + int dataSize = 20 * AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE; + int trainTestSplit = 300; + // detector interval + int interval = detectorIntervalMins; + int delta = 60000 * interval; + + int numberOfTrials = 20; + double prec = 0; + double recall = 0; + for (int z = 0; z < numberOfTrials; z++) { + // set up detector + detector = TestHelpers.AnomalyDetectorBuilder + .newInstance() + .setDetectionInterval(new IntervalTimeConfiguration(interval, ChronoUnit.MINUTES)) + .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) + .setShingleSize(TimeSeriesSettings.DEFAULT_SHINGLE_SIZE) + .build(); + + long seed = new Random().nextLong(); + LOG.info("seed = " + seed); + // create labelled data + MultiDimDataWithTime dataWithKeys = LabelledAnomalyGenerator + .getMultiDimData( + dataSize + detector.getShingleSize() - 1, + 50, + 100, + 5, + seed, + baseDimension, + false, + trainTestSplit, + delta, + false + ); + long[] timestamps = dataWithKeys.timestampsMs; + double[][] data = dataWithKeys.data; + when(clock.millis()).thenReturn(timestamps[trainTestSplit - 1]); + + // training data ranges from timestamps[0] ~ timestamps[trainTestSplit-1] + doAnswer(invocation -> { + GetRequest request = invocation.getArgument(0); + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(TestHelpers.createGetResponse(detector, detector.getId(), CommonName.CONFIG_INDEX)); + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(Optional.of(timestamps[0])); + return null; + }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any()); + + doAnswer(invocation -> { + List> ranges = invocation.getArgument(1); + List> coldStartSamples = new ArrayList<>(); + + Collections.sort(ranges, new Comparator>() { + @Override + public int compare(Entry p1, Entry p2) { + return Long.compare(p1.getKey(), p2.getKey()); + } + }); + for (int j = 0; j < ranges.size(); j++) { + Entry range = ranges.get(j); + Long start = range.getKey(); + int valueIndex = searchInsert(timestamps, start); + coldStartSamples.add(Optional.of(data[valueIndex])); + } + + ActionListener>> listener = invocation.getArgument(4); + listener.onResponse(coldStartSamples); + return null; + }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); + + EntityModel model = new EntityModel(entity, new ArrayDeque<>(), null); + modelState = new ModelState<>(model, modelId, detector.getId(), ModelType.ENTITY.getName(), clock, priority); + + released = new AtomicBoolean(); + + inProgressLatch = new CountDownLatch(1); + listener = ActionListener.wrap(() -> { + released.set(true); + inProgressLatch.countDown(); + }); + + entityColdStarter.trainModel(entity, detector.getId(), modelState, listener); + + checkSemaphoreRelease(); + assertTrue(model.getTrcf().isPresent()); + + int tp = 0; + int fp = 0; + int fn = 0; + long[] 
changeTimestamps = dataWithKeys.changeTimeStampsMs; + + for (int j = trainTestSplit; j < data.length; j++) { + ThresholdingResult result = modelManager + .getAnomalyResultForEntity(data[j], modelState, modelId, entity, detector.getShingleSize()); + if (result.getGrade() > 0) { + if (changeTimestamps[j] == 0) { + fp++; + } else { + tp++; + } + } else { + if (changeTimestamps[j] != 0) { + fn++; + } + // else ok + } + } + + if (tp + fp == 0) { + prec = 1; + } else { + prec = tp * 1.0 / (tp + fp); + } + + if (tp + fn == 0) { + recall = 1; + } else { + recall = tp * 1.0 / (tp + fn); + } + + // there are randomness involved; keep trying for a limited times + if (prec >= precisionThreshold && recall >= recallThreshold) { + break; + } + } + + assertTrue("precision is " + prec, prec >= precisionThreshold); + assertTrue("recall is " + recall, recall >= recallThreshold); + } + + public void testAccuracyTenMinuteInterval() throws Exception { + accuracyTemplate(10, 0.5f, 0.5f); + } + + public void testAccuracyThirteenMinuteInterval() throws Exception { + accuracyTemplate(13, 0.5f, 0.5f); + } + + public void testAccuracyOneMinuteIntervalNoInterpolation() throws Exception { + ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, false); + // for one minute interval, we need to disable interpolation to achieve good results + entityColdStarter = new EntityColdStarter( + clock, + threadPool, + stateManager, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.TIME_DECAY, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + imputer, + searchFeatureDao, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + featureManager, + settings, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + rcfSeed, + AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + ); + + modelManager = new ModelManager( + mock(CheckpointDao.class), + mock(Clock.class), + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.TIME_DECAY, + AnomalyDetectorSettings.NUM_MIN_SAMPLES, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.MIN_PREVIEW_SIZE, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, + entityColdStarter, + mock(FeatureManager.class), + mock(MemoryTracker.class), + settings, + clusterService + ); + + accuracyTemplate(1, 0.6f, 0.6f); + } + + private ModelState createStateForCacheRelease() { + inProgressLatch = new CountDownLatch(1); + releaseSemaphore = () -> { + released.set(true); + inProgressLatch.countDown(); + }; + listener = ActionListener.wrap(releaseSemaphore); + Queue samples = MLUtil.createQueueSamples(1); + EntityModel model = new EntityModel(entity, samples, null); + return new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority); + } + + public void testCacheReleaseAfterMaintenance() throws IOException, InterruptedException { + ModelState modelState = createStateForCacheRelease(); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(Optional.of(1602269260000L)); + return null; + }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any()); + + List> coldStartSamples = new ArrayList<>(); + + double[] sample1 = new double[] { 57.0 }; + double[] sample2 = new double[] { 1.0 }; + double[] sample3 = new double[] { -19.0 }; + + 
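// same three cold start samples as testColdStart; interpolation expands them into enough points + // to train, so the first trainModel call below should produce a TRCF model + 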
coldStartSamples.add(Optional.of(sample1)); + coldStartSamples.add(Optional.of(sample2)); + coldStartSamples.add(Optional.of(sample3)); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(4); + listener.onResponse(coldStartSamples); + return null; + }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); + + entityColdStarter.trainModel(entity, detectorId, modelState, listener); + checkSemaphoreRelease(); + assertTrue(modelState.getModel().getTrcf().isPresent()); + + modelState = createStateForCacheRelease(); + entityColdStarter.trainModel(entity, detectorId, modelState, listener); + checkSemaphoreRelease(); + // model is not trained as the door keeper remembers it and won't retry training + assertTrue(!modelState.getModel().getTrcf().isPresent()); + + // make sure when the next maintenance coming, current door keeper gets reset + // note our detector interval is 1 minute and the door keeper will expire in 60 intervals, which are 60 minutes + when(clock.instant()).thenReturn(Instant.now().plus(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ + 1, ChronoUnit.MINUTES)); + entityColdStarter.maintenance(); + + modelState = createStateForCacheRelease(); + entityColdStarter.trainModel(entity, detectorId, modelState, listener); + checkSemaphoreRelease(); + // model is trained as the door keeper gets reset + assertTrue(modelState.getModel().getTrcf().isPresent()); + } + + public void testCacheReleaseAfterClear() throws IOException, InterruptedException { + ModelState modelState = createStateForCacheRelease(); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(Optional.of(1602269260000L)); + return null; + }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any()); + + List> coldStartSamples = new ArrayList<>(); + + double[] sample1 = new double[] { 57.0 }; + double[] sample2 = new double[] { 1.0 }; + double[] sample3 = new double[] { -19.0 }; + + coldStartSamples.add(Optional.of(sample1)); + coldStartSamples.add(Optional.of(sample2)); + coldStartSamples.add(Optional.of(sample3)); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(4); + listener.onResponse(coldStartSamples); + return null; + }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); + + entityColdStarter.trainModel(entity, detectorId, modelState, listener); + checkSemaphoreRelease(); + assertTrue(modelState.getModel().getTrcf().isPresent()); + + entityColdStarter.clear(detectorId); + + modelState = createStateForCacheRelease(); + entityColdStarter.trainModel(entity, detectorId, modelState, listener); + checkSemaphoreRelease(); + // model is trained as the door keeper is regenerated after clearance + assertTrue(modelState.getModel().getTrcf().isPresent()); + } +} diff --git a/src/test/java/org/opensearch/ad/ml/EntityModelTests.java-e b/src/test/java/org/opensearch/ad/ml/EntityModelTests.java-e new file mode 100644 index 000000000..1f4afe829 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/EntityModelTests.java-e @@ -0,0 +1,73 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ml; + +import java.util.ArrayDeque; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.test.OpenSearchTestCase; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class EntityModelTests extends OpenSearchTestCase { + + private ThresholdedRandomCutForest trcf; + + @Before + public void setup() { + this.trcf = new ThresholdedRandomCutForest(ThresholdedRandomCutForest.builder().dimensions(2).internalShinglingEnabled(true)); + } + + public void testNullInternalSampleQueue() { + EntityModel model = new EntityModel(null, null, null); + model.addSample(new double[] { 0.8 }); + assertEquals(1, model.getSamples().size()); + } + + public void testNullInputSample() { + EntityModel model = new EntityModel(null, null, null); + model.addSample(null); + assertEquals(0, model.getSamples().size()); + } + + public void testEmptyInputSample() { + EntityModel model = new EntityModel(null, null, null); + model.addSample(new double[] {}); + assertEquals(0, model.getSamples().size()); + } + + @Test + public void trcf_constructor() { + EntityModel em = new EntityModel(null, new ArrayDeque<>(), trcf); + assertEquals(trcf, em.getTrcf().get()); + } + + @Test + public void clear() { + EntityModel em = new EntityModel(null, new ArrayDeque<>(), trcf); + + em.clear(); + + assertTrue(em.getSamples().isEmpty()); + assertFalse(em.getTrcf().isPresent()); + } + + @Test + public void setTrcf() { + EntityModel em = new EntityModel(null, null, null); + assertFalse(em.getTrcf().isPresent()); + + em.setTrcf(this.trcf); + assertTrue(em.getTrcf().isPresent()); + } +} diff --git a/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java index f9009b18c..6fd32c2c9 100644 --- a/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java +++ b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java @@ -34,7 +34,6 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.MemoryTracker; import org.opensearch.ad.feature.FeatureManager; import org.opensearch.ad.feature.SearchFeatureDao; @@ -44,6 +43,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.test.ClusterServiceUtils; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.IntervalTimeConfiguration; @@ -126,7 +126,7 @@ private void averageAccuracyTemplate( AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, AnomalyDetectorSettings.HOURLY_MAINTENANCE, threadPool, - AnomalyDetectorPlugin.AD_THREAD_POOL_NAME + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ); entityColdStarter = new EntityColdStarter( diff --git a/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java-e b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java-e new file mode 100644 index 000000000..6fd32c2c9 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java-e @@ -0,0 +1,343 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ml; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.when; + +import java.time.Clock; +import java.time.temporal.ChronoUnit; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.apache.lucene.tests.util.TimeUnits; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.ml.ModelManager.ModelType; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.ClusterServiceUtils; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import test.org.opensearch.ad.util.LabelledAnomalyGenerator; +import test.org.opensearch.ad.util.MultiDimDataWithTime; + +import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite; +import com.google.common.collect.ImmutableList; + +@TimeoutSuite(millis = 60 * TimeUnits.MINUTE) // rcf may be slow due to bounding box cache disabled +public class HCADModelPerfTests extends AbstractCosineDataTest { + + /** + * A template to perform precision/recall test by simulating HCAD logic with only one entity. 
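+ * Precision is tp / (tp + fp) and recall is tp / (tp + fn), measured on the points after the
+ * train/test split; the per-trial values are averaged and the averages must meet the given thresholds.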
+ * + * @param detectorIntervalMins Detector interval + * @param precisionThreshold precision threshold + * @param recallThreshold recall threshold + * @param baseDimension the number of dimensions + * @param anomalyIndependent whether anomalies in each dimension is generated independently + * @throws Exception when failing to create anomaly detector or creating training data + */ + @SuppressWarnings("unchecked") + private void averageAccuracyTemplate( + int detectorIntervalMins, + float precisionThreshold, + float recallThreshold, + int baseDimension, + boolean anomalyIndependent + ) throws Exception { + int dataSize = 20 * AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE; + int trainTestSplit = 300; + // detector interval + int interval = detectorIntervalMins; + int delta = 60000 * interval; + + int numberOfTrials = 10; + double prec = 0; + double recall = 0; + double totalPrec = 0; + double totalRecall = 0; + + // training data ranges from timestamps[0] ~ timestamps[trainTestSplit-1] + // set up detector + detector = TestHelpers.AnomalyDetectorBuilder + .newInstance() + .setDetectionInterval(new IntervalTimeConfiguration(interval, ChronoUnit.MINUTES)) + .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) + .setShingleSize(TimeSeriesSettings.DEFAULT_SHINGLE_SIZE) + .build(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + + listener.onResponse(TestHelpers.createGetResponse(detector, detector.getId(), CommonName.CONFIG_INDEX)); + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + + for (int z = 1; z <= numberOfTrials; z++) { + long seed = z; + LOG.info("seed = " + seed); + // recreate in each loop; otherwise, we will have heap overflow issue. + searchFeatureDao = mock(SearchFeatureDao.class); + clusterSettings = new ClusterSettings(Settings.EMPTY, nodestateSetting); + clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); + + featureManager = new FeatureManager( + searchFeatureDao, + imputer, + clock, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, + AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, + AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, + AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME + ); + + entityColdStarter = new EntityColdStarter( + clock, + threadPool, + stateManager, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.TIME_DECAY, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + imputer, + searchFeatureDao, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + featureManager, + settings, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + seed, + AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + ); + + modelManager = new ModelManager( + mock(CheckpointDao.class), + mock(Clock.class), + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.TIME_DECAY, + AnomalyDetectorSettings.NUM_MIN_SAMPLES, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + AnomalyDetectorSettings.MIN_PREVIEW_SIZE, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + 
AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, + entityColdStarter, + mock(FeatureManager.class), + mock(MemoryTracker.class), + settings, + clusterService + ); + + // create labelled data + MultiDimDataWithTime dataWithKeys = LabelledAnomalyGenerator + .getMultiDimData( + dataSize + detector.getShingleSize() - 1, + 50, + 100, + 5, + seed, + baseDimension, + false, + trainTestSplit, + delta, + anomalyIndependent + ); + + long[] timestamps = dataWithKeys.timestampsMs; + double[][] data = dataWithKeys.data; + when(clock.millis()).thenReturn(timestamps[trainTestSplit - 1]); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(Optional.of(timestamps[0])); + return null; + }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any()); + + doAnswer(invocation -> { + List> ranges = invocation.getArgument(1); + List> coldStartSamples = new ArrayList<>(); + + Collections.sort(ranges, new Comparator>() { + @Override + public int compare(Entry p1, Entry p2) { + return Long.compare(p1.getKey(), p2.getKey()); + } + }); + for (int j = 0; j < ranges.size(); j++) { + Entry range = ranges.get(j); + Long start = range.getKey(); + int valueIndex = searchInsert(timestamps, start); + coldStartSamples.add(Optional.of(data[valueIndex])); + } + + ActionListener>> listener = invocation.getArgument(4); + listener.onResponse(coldStartSamples); + return null; + }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); + + entity = Entity.createSingleAttributeEntity("field", entityName + z); + EntityModel model = new EntityModel(entity, new ArrayDeque<>(), null); + ModelState modelState = new ModelState<>( + model, + entity.getModelId(detectorId).get(), + detector.getId(), + ModelType.ENTITY.getName(), + clock, + priority + ); + + released = new AtomicBoolean(); + + inProgressLatch = new CountDownLatch(1); + listener = ActionListener.wrap(() -> { + released.set(true); + inProgressLatch.countDown(); + }); + + entityColdStarter.trainModel(entity, detector.getId(), modelState, listener); + + checkSemaphoreRelease(); + assertTrue(model.getTrcf().isPresent()); + + int tp = 0; + int fp = 0; + int fn = 0; + long[] changeTimestamps = dataWithKeys.changeTimeStampsMs; + + for (int j = trainTestSplit; j < data.length; j++) { + ThresholdingResult result = modelManager + .getAnomalyResultForEntity(data[j], modelState, modelId, entity, detector.getShingleSize()); + if (result.getGrade() > 0) { + if (changeTimestamps[j] == 0) { + fp++; + } else { + tp++; + } + } else { + if (changeTimestamps[j] != 0) { + fn++; + } + // else ok + } + } + + if (tp + fp == 0) { + prec = 1; + } else { + prec = tp * 1.0 / (tp + fp); + } + + if (tp + fn == 0) { + recall = 1; + } else { + recall = tp * 1.0 / (tp + fn); + } + + totalPrec += prec; + totalRecall += recall; + modelState = null; + dataWithKeys = null; + reset(searchFeatureDao); + searchFeatureDao = null; + clusterService = null; + } + + double avgPrec = totalPrec / numberOfTrials; + double avgRecall = totalRecall / numberOfTrials; + LOG.info("{} features, Interval {}, Precision: {}, recall: {}", baseDimension, detectorIntervalMins, avgPrec, avgRecall); + assertTrue("average precision is " + avgPrec, avgPrec >= precisionThreshold); + assertTrue("average recall is " + avgRecall, avgRecall >= recallThreshold); + } + + /** + * Split average accuracy tests into two in case of time out per test. 
+ * @throws Exception when failing to perform tests + */ + public void testAverageAccuracyDependent() throws Exception { + LOG.info("Anomalies are injected dependently"); + + // 10 minute interval, 4 features + averageAccuracyTemplate(10, 0.4f, 0.3f, 4, false); + + // 10 minute interval, 2 features + averageAccuracyTemplate(10, 0.4f, 0.4f, 2, false); + + // 10 minute interval, 1 features + averageAccuracyTemplate(10, 0.4f, 0.4f, 1, false); + + // 5 minute interval, 4 features + averageAccuracyTemplate(5, 0.4f, 0.3f, 4, false); + + // 5 minute interval, 2 features + averageAccuracyTemplate(5, 0.4f, 0.4f, 2, false); + + // 5 minute interval, 1 features + averageAccuracyTemplate(5, 0.4f, 0.4f, 1, false); + } + + /** + * Split average accuracy tests into two in case of time out per test. + * @throws Exception when failing to perform tests + */ + public void testAverageAccuracyIndependent() throws Exception { + LOG.info("Anomalies are injected independently"); + + // 10 minute interval, 4 features + averageAccuracyTemplate(10, 0.3f, 0.1f, 4, true); + + // 10 minute interval, 2 features + averageAccuracyTemplate(10, 0.4f, 0.4f, 2, true); + + // 10 minute interval, 1 features + averageAccuracyTemplate(10, 0.3f, 0.4f, 1, true); + + // 5 minute interval, 4 features + averageAccuracyTemplate(5, 0.2f, 0.1f, 4, true); + + // 5 minute interval, 2 features + averageAccuracyTemplate(5, 0.4f, 0.4f, 2, true); + + // 5 minute interval, 1 features + averageAccuracyTemplate(5, 0.3f, 0.4f, 1, true); + } +} diff --git a/src/test/java/org/opensearch/ad/ml/HybridThresholdingModelTests.java-e b/src/test/java/org/opensearch/ad/ml/HybridThresholdingModelTests.java-e new file mode 100644 index 000000000..be8ff9525 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/HybridThresholdingModelTests.java-e @@ -0,0 +1,227 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; + +import junitparams.JUnitParamsRunner; +import junitparams.Parameters; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(JUnitParamsRunner.class) +public class HybridThresholdingModelTests { + + /* + * Returns samples from the log-normal distribution. + * + * Given random Gaussian samples X ~ N(mu, sigma), samples from the + * log-normal distribution are given by Y ~ e^X. 
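+ * Equivalently, log(Y) ~ N(mu, sigma): e^mu is the median of Y and e^(mu + sigma) sits at the
+ * 84.134th percentile, which is why those scores are expected to grade to 0.5 and 0.84134 below.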
+ * + * @param sampleSize number of log-normal samples to generate + * @param mu mean + * @param sigma standard deviation + */ + private double[] logNormalSamples(int sampleSize, double mu, double sigma) { + NormalDistribution distribution = new NormalDistribution(mu, sigma); + distribution.reseedRandomGenerator(0L); + + double[] samples = new double[sampleSize]; + for (int i = 0; i < sampleSize; i++) { + samples[i] = Math.exp(distribution.sample()); + } + return samples; + } + + private Object[] getTestGettersParameters() { + double minPvalueThreshold = 0.8; + double maxRankError = 0.001; + double maxScore = 10; + int numLogNormalQuantiles = 0; + int downsampleNumSamples = 100_000; + long downsampleMaxNumObservations = 10_000_000L; + HybridThresholdingModel model = new HybridThresholdingModel( + minPvalueThreshold, + maxRankError, + maxScore, + numLogNormalQuantiles, + downsampleNumSamples, + downsampleMaxNumObservations + ); + + return new Object[] { + new Object[] { + model, + minPvalueThreshold, + maxRankError, + maxScore, + numLogNormalQuantiles, + downsampleNumSamples, + downsampleMaxNumObservations } }; + } + + @Test + @Parameters(method = "getTestGettersParameters") + public void testGetters( + HybridThresholdingModel model, + double minPvalueThreshold, + double maxRankError, + double maxScore, + int numLogNormalQuantiles, + int downsampleNumSamples, + long downsampleMaxNumObservations + ) { + double delta = 1e-4; + assertEquals(minPvalueThreshold, model.getMinPvalueThreshold(), delta); + assertEquals(maxRankError, model.getMaxRankError(), delta); + assertEquals(maxScore, model.getMaxScore(), delta); + assertEquals(numLogNormalQuantiles, model.getNumLogNormalQuantiles()); + assertEquals(downsampleNumSamples, model.getDownsampleNumSamples()); + assertEquals(downsampleMaxNumObservations, model.getDownsampleMaxNumObservations()); + } + + @Test + public void emptyConstructor_returnNonNullInstance() { + assertTrue(new HybridThresholdingModel() != null); + } + + private Object[] getThrowsExpectedInitializationExceptionParameters() { + return new Object[] { + new Object[] { 0.0, 0.001, 10, 10, 100, 1000 }, + new Object[] { 1.0, 0.001, 10, 10, 100, 1000 }, + new Object[] { 0.9, 0.123, 10, 10, 100, 1000 }, + new Object[] { 0.9, -0.01, 10, 10, 100, 1000 }, + new Object[] { 0.9, 0.001, -8, 10, 100, 1000 }, + new Object[] { 0.9, 0.001, 10, -1, 100, 1000 }, + new Object[] { 0.9, 0.001, 10, 10, 1, 1000 }, + new Object[] { 0.9, 0.001, 10, 10, 0, 1000 }, + new Object[] { 0.9, 0.001, 10, 10, 10_000, 1000 }, }; + } + + @Test(expected = IllegalArgumentException.class) + @Parameters(method = "getThrowsExpectedInitializationExceptionParameters") + public void throwsExpectedInitializationExceptions( + double minPvalueThreshold, + double maxRankError, + double maxScore, + int numLogNormalQuantiles, + int downsampleNumSamples, + int downsampleMaxNumObservations + ) { + HybridThresholdingModel invalidModel = new HybridThresholdingModel( + minPvalueThreshold, + maxRankError, + maxScore, + numLogNormalQuantiles, + downsampleNumSamples, + downsampleMaxNumObservations + ); + } + + private Object[] getTestExpectedGradesWithUpdateParameters() { + double mu = 1.2; + double sigma = 3.4; + double[] trainingAnomalyScores = logNormalSamples(1_000_000, 1.2, 3.4); + double maxScore = Arrays.stream(trainingAnomalyScores).max().getAsDouble(); + HybridThresholdingModel model = new HybridThresholdingModel(1e-8, 1e-5, maxScore, 10_000, 2, 5_000_000); + model.train(trainingAnomalyScores); + + return new Object[] { + new 
Object[] { + model, + new double[] {}, + new double[] { 0.0, Math.exp(mu), Math.exp(mu + sigma), maxScore, HybridThresholdingModel.MIN_SCORE }, + new double[] { 0.0, 0.5, 0.84134, 1.0, 0 }, }, + new Object[] { + model, + trainingAnomalyScores, + new double[] { 0.0, Math.exp(mu), Math.exp(mu + sigma), maxScore }, + new double[] { 0.0, 0.5, 0.84134, 1.0 }, }, + new Object[] { + new HybridThresholdingModel(1e-8, 1e-5, maxScore, 10_000, 2, 5_000_000), + new double[0], + new double[] { 1.0 }, + new double[] { 0.0 }, } }; + } + + @Test + @Parameters(method = "getTestExpectedGradesWithUpdateParameters") + public void testExpectedGradesWithUpdate( + HybridThresholdingModel model, + double[] updateAnomalyScores, + double[] testAnomalyScores, + double[] expectedGrades + ) { + double delta = 1e-3; + for (double anomalyScore : updateAnomalyScores) { + model.update(anomalyScore); + } + + for (int i = 0; i < testAnomalyScores.length; i++) { + double expectedGrade = expectedGrades[i]; + double actualGrade = model.grade(testAnomalyScores[i]); + assertEquals(expectedGrade, actualGrade, delta); + } + } + + private Object[] getTestConfidenceParameters() { + double maxRankError = 0.001; + double[] trainingAnomalyScores = logNormalSamples(1_000_000, 1.2, 3.4); + double maxScore = Arrays.stream(trainingAnomalyScores).max().getAsDouble(); + HybridThresholdingModel model = new HybridThresholdingModel(0.8, maxRankError, maxScore, 1000, 2, 5_000_000); + model.train(trainingAnomalyScores); + for (double anomalyScore : trainingAnomalyScores) { + model.update(anomalyScore); + } + + HybridThresholdingModel newModel = new HybridThresholdingModel(0.8, maxRankError, maxScore, 1000, 2, 5_000_000); + + return new Object[] { new Object[] { model, 0.99 }, new Object[] { newModel, 0.99 } }; + } + + @Test + @Parameters(method = "getTestConfidenceParameters") + public void testConfidence(HybridThresholdingModel model, double expectedConfidence) { + double delta = 1e-2; + assertEquals(expectedConfidence, model.confidence(), delta); + } + + private Object[] getTestDownsamplingParameters() { + double[] scores = { 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + HybridThresholdingModel model = new HybridThresholdingModel(1e-8, 1e-4, 9, 0, 5, 9); + for (double score : scores) + model.update(score); + + return new Object[] { + new Object[] { + model, // model ECDF should be equal to [1, 3, 5, 7, 9] + new double[] { 1.0, 1.1, 2.0, 3.0, 4.0, 4.1, 5.0, 5.1, 7.1, 10.1 }, + new double[] { 0.0, 0.2, 0.2, 0.2, 0.4, 0.4, 0.4, 0.6, 0.8, 1.0 }, }, }; + } + + @Test + @Parameters(method = "getTestDownsamplingParameters") + public void testDownsampling(HybridThresholdingModel model, double[] scores, double[] expectedGrades) { + double[] actualGrades = new double[scores.length]; + for (int i = 0; i < actualGrades.length; i++) + actualGrades[i] = model.grade(scores[i]); + + final double delta = 0.001; + assertArrayEquals(expectedGrades, actualGrades, delta); + } +} diff --git a/src/test/java/org/opensearch/ad/ml/KllFloatsSketchSerDeTests.java-e b/src/test/java/org/opensearch/ad/ml/KllFloatsSketchSerDeTests.java-e new file mode 100644 index 000000000..3b72d9d1a --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/KllFloatsSketchSerDeTests.java-e @@ -0,0 +1,61 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ml; + +import static org.junit.Assert.assertEquals; + +import java.util.Random; + +import org.junit.Before; +import org.junit.Test; + +import com.google.gson.Gson; + +public class KllFloatsSketchSerDeTests { + + private Gson gson; + + private HybridThresholdingModel hybridModel; + + @Before + public void setup() { + gson = new Gson(); + + hybridModel = new HybridThresholdingModel(/*minPvalueThreshold*/ 0.95, + /*maxRankError*/ 1e-4, + /*maxScore*/ 4, + /*numLogNormalQuantiles*/ 10000, + /*downsampleNumSamples*/ 100_000, + /*downsampleMaxNumObservations*/ 200_000L + ); + } + + @Test + public void serialize_deserialize_returnOriginalModel() { + hybridModel.train(new Random().doubles(10_000L, 0.1, 3.9).toArray()); + + String json = gson.toJson(hybridModel); + HybridThresholdingModel deserialized = gson.fromJson(json, HybridThresholdingModel.class); + + double delta = 1e-6; + assertEquals(hybridModel.getMinPvalueThreshold(), deserialized.getMinPvalueThreshold(), delta); + assertEquals(hybridModel.getMaxRankError(), deserialized.getMaxRankError(), delta); + assertEquals(hybridModel.getMaxScore(), deserialized.getMaxScore(), delta); + assertEquals(hybridModel.getNumLogNormalQuantiles(), deserialized.getNumLogNormalQuantiles()); + for (double score : new Random().doubles(1000L, 0.1, 3.9).toArray()) { + assertEquals(hybridModel.grade(score), deserialized.grade(score), delta); + assertEquals(hybridModel.confidence(), deserialized.confidence(), delta); + hybridModel.update(score); + deserialized.update(score); + } + } +} diff --git a/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java b/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java index 531f251ed..7d981a297 100644 --- a/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java +++ b/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java @@ -55,7 +55,6 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionListener; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.MemoryTracker; import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.breaker.ADCircuitBreakerService; @@ -73,6 +72,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.monitor.jvm.JvmService; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; @@ -211,7 +211,7 @@ public void setup() { when(rcf.process(any(), anyLong())).thenReturn(descriptor); ExecutorService executorService = mock(ExecutorService.class); - when(threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + when(threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); doAnswer(invocation -> { Runnable runnable = invocation.getArgument(0); runnable.run(); @@ -929,7 +929,7 @@ public void getEmptyStateFullSamples() { AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, AnomalyDetectorSettings.HOURLY_MAINTENANCE, threadPool, - AnomalyDetectorPlugin.AD_THREAD_POOL_NAME + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ); CheckpointWriteWorker checkpointWriteQueue = mock(CheckpointWriteWorker.class); diff --git a/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java-e b/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java-e new file 
mode 100644 index 000000000..7d981a297 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java-e @@ -0,0 +1,1091 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import junitparams.JUnitParamsRunner; +import junitparams.Parameters; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.ml.ModelManager.ModelType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.monitor.jvm.JvmService; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.powermock.modules.junit4.PowerMockRunner; +import org.powermock.modules.junit4.PowerMockRunnerDelegate; + +import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; + +import com.amazon.randomcutforest.RandomCutForest; +import 
com.amazon.randomcutforest.parkservices.AnomalyDescriptor; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; +import com.amazon.randomcutforest.returntypes.DiVector; + +@RunWith(PowerMockRunner.class) +@PowerMockRunnerDelegate(JUnitParamsRunner.class) +@SuppressWarnings("unchecked") +public class ModelManagerTests { + + private ModelManager modelManager; + + @Mock + private AnomalyDetector anomalyDetector; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private DiscoveryNodeFilterer nodeFilter; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private JvmService jvmService; + + @Mock + private CheckpointDao checkpointDao; + + @Mock + private Clock clock; + + @Mock + private FeatureManager featureManager; + + @Mock + private EntityColdStarter entityColdStarter; + + @Mock + private EntityCache cache; + + @Mock + private ModelState modelState; + + @Mock + private EntityModel entityModel; + + @Mock + private ThresholdedRandomCutForest trcf; + + private double modelDesiredSizePercentage; + private double modelMaxSizePercentage; + private int numTrees; + private int numSamples; + private int numFeatures; + private double rcfTimeDecay; + private int numMinSamples; + private double thresholdMinPvalue; + private int minPreviewSize; + private Duration modelTtl; + private Duration checkpointInterval; + private ThresholdedRandomCutForest rcf; + + @Mock + private HybridThresholdingModel hybridThresholdingModel; + + @Mock + private ThreadPool threadPool; + + private String detectorId; + private String rcfModelId; + private String thresholdModelId; + private int shingleSize; + private Settings settings; + private ClusterService clusterService; + private double[] attribution; + private double[] point; + private DiVector attributionVec; + private ClusterSettings clusterSettings; + + @Mock + private ActionListener rcfResultListener; + + @Mock + private ActionListener thresholdResultListener; + private MemoryTracker memoryTracker; + private Instant now; + + @Mock + private ADCircuitBreakerService adCircuitBreakerService; + + private String modelId = "modelId"; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + + modelDesiredSizePercentage = 0.001; + modelMaxSizePercentage = 0.1; + numTrees = 100; + numSamples = 10; + numFeatures = 1; + rcfTimeDecay = 1.0 / 1024; + numMinSamples = 1; + thresholdMinPvalue = 0.95; + minPreviewSize = 500; + modelTtl = Duration.ofHours(1); + checkpointInterval = Duration.ofHours(1); + shingleSize = 1; + attribution = new double[] { 1, 1 }; + attributionVec = new DiVector(attribution.length); + for (int i = 0; i < attribution.length; i++) { + attributionVec.high[i] = attribution[i]; + attributionVec.low[i] = attribution[i] - 1; + } + point = new double[] { 2 }; + + rcf = spy(ThresholdedRandomCutForest.builder().dimensions(numFeatures).sampleSize(numSamples).numberOfTrees(numTrees).build()); + double score = 11.; + + double confidence = 0.091353632; + double grade = 0.1; + AnomalyDescriptor descriptor = new AnomalyDescriptor(point, 0); + descriptor.setRCFScore(score); + descriptor.setNumberOfTrees(numTrees); + descriptor.setDataConfidence(confidence); + descriptor.setAnomalyGrade(grade); + descriptor.setAttribution(attributionVec); + descriptor.setTotalUpdates(numSamples); + when(rcf.process(any(), anyLong())).thenReturn(descriptor); + + ExecutorService executorService = mock(ExecutorService.class); + when(threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + doAnswer(invocation -> { + 
Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + + now = Instant.now(); + when(clock.instant()).thenReturn(now); + + memoryTracker = mock(MemoryTracker.class); + when(memoryTracker.isHostingAllowed(anyString(), any())).thenReturn(true); + + settings = Settings + .builder() + .put("plugins.anomaly_detection.model_max_size_percent", modelMaxSizePercentage) + .put("plugins.anomaly_detection.checkpoint_saving_freq", TimeValue.timeValueHours(12)) + .build(); + + modelManager = spy( + new ModelManager( + checkpointDao, + clock, + numTrees, + numSamples, + rcfTimeDecay, + numMinSamples, + thresholdMinPvalue, + minPreviewSize, + modelTtl, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, + entityColdStarter, + featureManager, + memoryTracker, + settings, + null + ) + ); + + detectorId = "detectorId"; + rcfModelId = "detectorId_model_rcf_1"; + thresholdModelId = "detectorId_model_threshold"; + + when(this.modelState.getModel()).thenReturn(this.entityModel); + when(this.entityModel.getTrcf()).thenReturn(Optional.of(this.trcf)); + + when(anomalyDetector.getShingleSize()).thenReturn(shingleSize); + } + + private Object[] getDetectorIdForModelIdData() { + return new Object[] { + new Object[] { "testId_model_threshold", "testId" }, + new Object[] { "test_id_model_threshold", "test_id" }, + new Object[] { "test_model_id_model_threshold", "test_model_id" }, + new Object[] { "testId_model_rcf_1", "testId" }, + new Object[] { "test_Id_model_rcf_1", "test_Id" }, + new Object[] { "test_model_rcf_Id_model_rcf_1", "test_model_rcf_Id" }, }; + }; + + @Test + @Parameters(method = "getDetectorIdForModelIdData") + public void getDetectorIdForModelId_returnExpectedId(String modelId, String expectedDetectorId) { + assertEquals(expectedDetectorId, SingleStreamModelIdMapper.getDetectorIdForModelId(modelId)); + } + + private Object[] getDetectorIdForModelIdIllegalArgument() { + return new Object[] { new Object[] { "testId" }, new Object[] { "testid_" }, new Object[] { "_testId" }, }; + } + + @Test(expected = IllegalArgumentException.class) + @Parameters(method = "getDetectorIdForModelIdIllegalArgument") + public void getDetectorIdForModelId_throwIllegalArgument_forInvalidId(String modelId) { + SingleStreamModelIdMapper.getDetectorIdForModelId(modelId); + } + + private Map<String, DiscoveryNode> createDataNodes(int numDataNodes) { + Map<String, DiscoveryNode> dataNodes = new HashMap<>(); + for (int i = 0; i < numDataNodes; i++) { + dataNodes.put("foo" + i, mock(DiscoveryNode.class)); + } + return dataNodes; + } + + private Object[] getPartitionedForestSizesData() { + ThresholdedRandomCutForest rcf = ThresholdedRandomCutForest.builder().dimensions(1).sampleSize(10).numberOfTrees(100).build(); + return new Object[] { + // one partition given sufficiently large nodes + new Object[] { rcf, 100L, 100_000L, createDataNodes(10), pair(1, 100) }, + // two partitions given sufficient medium nodes + new Object[] { rcf, 100L, 50_000L, createDataNodes(10), pair(2, 50) }, + // ten partitions given sufficient small nodes + new Object[] { rcf, 100L, 10_000L, createDataNodes(10), pair(10, 10) }, + // five double-sized partitions given fewer small nodes + new Object[] { rcf, 100L, 10_000L, createDataNodes(5), pair(5, 20) }, + // one large-sized partition given one small node + new Object[] { rcf, 100L, 1_000L, createDataNodes(1), pair(1, 100) } }; + } + + private Object[] estimateModelSizeData() { + return new Object[] { + new Object[] {
ThresholdedRandomCutForest.builder().dimensions(1).sampleSize(256).numberOfTrees(100).build(), 819200L }, + new Object[] { ThresholdedRandomCutForest.builder().dimensions(5).sampleSize(256).numberOfTrees(100).build(), 4096000L } }; + } + + @Parameters(method = "estimateModelSizeData") + public void estimateModelSize_returnExpected(ThresholdedRandomCutForest rcf, long expectedSize) { + assertEquals(expectedSize, memoryTracker.estimateTRCFModelSize(rcf)); + } + + @Test + public void getRcfResult_returnExpectedToListener() { + double[] point = new double[0]; + ThresholdedRandomCutForest rForest = mock(ThresholdedRandomCutForest.class); + RandomCutForest rcf = mock(RandomCutForest.class); + when(rForest.getForest()).thenReturn(rcf); + // input length is 2 + when(rcf.getDimensions()).thenReturn(16); + when(rcf.getShingleSize()).thenReturn(8); + double score = 11.; + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(rForest)); + return null; + }).when(checkpointDao).getTRCFModel(eq(rcfModelId), any(ActionListener.class)); + + double confidence = 0.091353632; + double grade = 0.1; + int relativeIndex = 0; + double[] currentTimeAttribution = new double[] { 0.5, 0.5 }; + double[] pastalues = new double[] { 123, 456 }; + double[][] expectedValuesList = new double[][] { new double[] { 789, 12 } }; + double[] likelihood = new double[] { 1 }; + double threshold = 1.1d; + + AnomalyDescriptor descriptor = new AnomalyDescriptor(point, 0); + descriptor.setRCFScore(score); + descriptor.setNumberOfTrees(numTrees); + descriptor.setDataConfidence(confidence); + descriptor.setAnomalyGrade(grade); + descriptor.setAttribution(attributionVec); + descriptor.setTotalUpdates(numSamples); + descriptor.setRelativeIndex(relativeIndex); + descriptor.setRelevantAttribution(currentTimeAttribution); + descriptor.setPastValues(pastalues); + descriptor.setExpectedValuesList(expectedValuesList); + descriptor.setLikelihoodOfValues(likelihood); + descriptor.setThreshold(threshold); + + when(rForest.process(any(), anyLong())).thenReturn(descriptor); + + ActionListener listener = mock(ActionListener.class); + modelManager.getTRcfResult(detectorId, rcfModelId, point, listener); + + ThresholdingResult expected = new ThresholdingResult( + grade, + confidence, + score, + numSamples, + relativeIndex, + currentTimeAttribution, + pastalues, + expectedValuesList, + likelihood, + threshold, + numTrees + ); + verify(listener).onResponse(eq(expected)); + + descriptor.setTotalUpdates(numSamples + 1L); + listener = mock(ActionListener.class); + modelManager.getTRcfResult(detectorId, rcfModelId, point, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(ThresholdingResult.class); + verify(listener).onResponse(responseCaptor.capture()); + assertEquals(0.091353632, responseCaptor.getValue().getConfidence(), 1e-6); + } + + @Test + public void getRcfResult_throwToListener_whenNoCheckpoint() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(checkpointDao).getTRCFModel(eq(rcfModelId), any(ActionListener.class)); + + ActionListener listener = mock(ActionListener.class); + modelManager.getTRcfResult(detectorId, rcfModelId, new double[0], listener); + + verify(listener).onFailure(any(ResourceNotFoundException.class)); + } + + @Test + public void getRcfResult_throwToListener_whenHeapLimitExceed() { + doAnswer(invocation -> { + ActionListener> listener = 
invocation.getArgument(1); + listener.onResponse(Optional.of(rcf)); + return null; + }).when(checkpointDao).getTRCFModel(eq(rcfModelId), any(ActionListener.class)); + + when(jvmService.info().getMem().getHeapMax().getBytes()).thenReturn(1_000L); + + MemoryTracker memoryTracker = new MemoryTracker( + jvmService, + modelMaxSizePercentage, + modelDesiredSizePercentage, + null, + adCircuitBreakerService + ); + + ActionListener listener = mock(ActionListener.class); + + // use new memoryTracker + modelManager = spy( + new ModelManager( + checkpointDao, + clock, + numTrees, + numSamples, + rcfTimeDecay, + numMinSamples, + thresholdMinPvalue, + minPreviewSize, + modelTtl, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, + entityColdStarter, + featureManager, + memoryTracker, + settings, + null + ) + ); + + modelManager.getTRcfResult(detectorId, rcfModelId, new double[0], listener); + + verify(listener).onFailure(any(LimitExceededException.class)); + } + + @Test + public void getThresholdingResult_returnExpectedToListener() { + double score = 1.; + double grade = 0.; + double confidence = 0.5; + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(hybridThresholdingModel)); + return null; + }).when(checkpointDao).getThresholdModel(eq(thresholdModelId), any(ActionListener.class)); + when(hybridThresholdingModel.grade(score)).thenReturn(grade); + when(hybridThresholdingModel.confidence()).thenReturn(confidence); + + ActionListener listener = mock(ActionListener.class); + modelManager.getThresholdingResult(detectorId, thresholdModelId, score, listener); + + ThresholdingResult expected = new ThresholdingResult(grade, confidence, score); + verify(listener).onResponse(eq(expected)); + + listener = mock(ActionListener.class); + modelManager.getThresholdingResult(detectorId, thresholdModelId, score, listener); + verify(listener).onResponse(eq(expected)); + } + + @Test + public void getThresholdingResult_throwToListener_withNoCheckpoint() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(checkpointDao).getThresholdModel(eq(thresholdModelId), any(ActionListener.class)); + + ActionListener listener = mock(ActionListener.class); + modelManager.getThresholdingResult(detectorId, thresholdModelId, 0, listener); + + verify(listener).onFailure(any(ResourceNotFoundException.class)); + } + + @Test + public void getThresholdingResult_notUpdate_withZeroScore() { + double score = 0.0; + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(hybridThresholdingModel)); + return null; + }).when(checkpointDao).getThresholdModel(eq(thresholdModelId), any(ActionListener.class)); + + ActionListener listener = mock(ActionListener.class); + modelManager.getThresholdingResult(detectorId, thresholdModelId, score, listener); + + verify(hybridThresholdingModel, never()).update(score); + } + + @Test + public void getAllModelIds_returnAllIds_forRcfAndThreshold() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(hybridThresholdingModel)); + return null; + }).when(checkpointDao).getThresholdModel(eq(thresholdModelId), any(ActionListener.class)); + modelManager.getThresholdingResult(detectorId, thresholdModelId, 0, thresholdResultListener); + + assertEquals(Stream.of(thresholdModelId).collect(Collectors.toSet()), modelManager.getAllModelIds()); 
+ } + + @Test + public void getAllModelIds_returnEmpty_forNoModels() { + assertEquals(Collections.emptySet(), modelManager.getAllModelIds()); + } + + @Test + public void stopModel_returnExpectedToListener_whenRcfStop() { + ThresholdedRandomCutForest forest = mock(ThresholdedRandomCutForest.class); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(forest)); + return null; + }).when(checkpointDao).getTRCFModel(eq(rcfModelId), any(ActionListener.class)); + + modelManager.getTRcfResult(detectorId, rcfModelId, new double[0], rcfResultListener); + when(clock.instant()).thenReturn(Instant.EPOCH); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(checkpointDao).putTRCFCheckpoint(eq(rcfModelId), eq(forest), any(ActionListener.class)); + + ActionListener listener = mock(ActionListener.class); + modelManager.stopModel(detectorId, rcfModelId, listener); + + verify(listener).onResponse(eq(null)); + } + + @Test + public void stopModel_returnExpectedToListener_whenThresholdStop() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(hybridThresholdingModel)); + return null; + }).when(checkpointDao).getThresholdModel(eq(thresholdModelId), any(ActionListener.class)); + modelManager.getThresholdingResult(detectorId, thresholdModelId, 0, thresholdResultListener); + when(clock.instant()).thenReturn(Instant.EPOCH); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(checkpointDao).putThresholdCheckpoint(eq(thresholdModelId), eq(hybridThresholdingModel), any(ActionListener.class)); + + ActionListener listener = mock(ActionListener.class); + modelManager.stopModel(detectorId, thresholdModelId, listener); + + verify(listener).onResponse(eq(null)); + } + + @Test + public void stopModel_throwToListener_whenCheckpointFail() { + ThresholdedRandomCutForest forest = mock(ThresholdedRandomCutForest.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(forest)); + return null; + }).when(checkpointDao).getTRCFModel(eq(rcfModelId), any(ActionListener.class)); + modelManager.getTRcfResult(detectorId, rcfModelId, new double[0], rcfResultListener); + when(clock.instant()).thenReturn(Instant.EPOCH); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException()); + return null; + }).when(checkpointDao).putTRCFCheckpoint(eq(rcfModelId), eq(forest), any(ActionListener.class)); + + ActionListener listener = mock(ActionListener.class); + modelManager.stopModel(detectorId, rcfModelId, listener); + + verify(listener).onFailure(any(Exception.class)); + } + + @Test + public void clear_callListener_whenRcfDeleted() { + String otherModelId = detectorId + rcfModelId; + RandomCutForest forest = mock(RandomCutForest.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(forest)); + return null; + }).when(checkpointDao).getTRCFModel(eq(rcfModelId), any(ActionListener.class)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(forest)); + return null; + }).when(checkpointDao).getTRCFModel(eq(otherModelId), any(ActionListener.class)); + modelManager.getTRcfResult(detectorId, 
rcfModelId, new double[0], rcfResultListener); + modelManager.getTRcfResult(otherModelId, otherModelId, new double[0], rcfResultListener); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(checkpointDao).deleteModelCheckpoint(eq(rcfModelId), any(ActionListener.class)); + + ActionListener listener = mock(ActionListener.class); + modelManager.clear(detectorId, listener); + + verify(listener).onResponse(null); + } + + @Test + public void clear_callListener_whenThresholdDeleted() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(hybridThresholdingModel)); + return null; + }).when(checkpointDao).getThresholdModel(eq(thresholdModelId), any(ActionListener.class)); + + modelManager.getThresholdingResult(detectorId, thresholdModelId, 0, thresholdResultListener); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(checkpointDao).deleteModelCheckpoint(eq(thresholdModelId), any(ActionListener.class)); + + ActionListener listener = mock(ActionListener.class); + modelManager.clear(detectorId, listener); + + verify(listener).onResponse(null); + } + + @Test + public void clear_throwToListener_whenDeleteFail() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(rcf)); + return null; + }).when(checkpointDao).getTRCFModel(eq(rcfModelId), any(ActionListener.class)); + modelManager.getTRcfResult(detectorId, rcfModelId, new double[0], rcfResultListener); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException()); + return null; + }).when(checkpointDao).deleteModelCheckpoint(eq(rcfModelId), any(ActionListener.class)); + + ActionListener listener = mock(ActionListener.class); + modelManager.clear(detectorId, listener); + + verify(listener).onFailure(any(Exception.class)); + } + + @Test + public void trainModel_returnExpectedToListener_putCheckpoints() { + double[][] trainData = new Random().doubles().limit(100).mapToObj(d -> new double[] { d }).toArray(double[][]::new); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(checkpointDao).putTRCFCheckpoint(any(), any(), any(ActionListener.class)); + + ActionListener listener = mock(ActionListener.class); + modelManager.trainModel(anomalyDetector, trainData, listener); + + verify(listener).onResponse(eq(null)); + verify(checkpointDao, times(1)).putTRCFCheckpoint(any(), any(), any()); + } + + private Object[] trainModelIllegalArgumentData() { + return new Object[] { new Object[] { new double[][] {} }, new Object[] { new double[][] { {} } } }; + } + + @Test + @Parameters(method = "trainModelIllegalArgumentData") + public void trainModel_throwIllegalArgumentToListener_forInvalidTrainData(double[][] trainData) { + ActionListener listener = mock(ActionListener.class); + modelManager.trainModel(anomalyDetector, trainData, listener); + + verify(listener).onFailure(any(IllegalArgumentException.class)); + } + + @Test + public void trainModel_throwLimitExceededToListener_whenLimitExceed() { + doThrow(new LimitExceededException(null, null)).when(checkpointDao).putTRCFCheckpoint(any(), any(), any()); + + ActionListener listener = mock(ActionListener.class); + modelManager.trainModel(anomalyDetector, new double[][] { { 0 } }, 
listener); + + verify(listener).onFailure(any(LimitExceededException.class)); + } + + @Test + public void getRcfModelId_returnNonEmptyString() { + String rcfModelId = SingleStreamModelIdMapper.getRcfModelId(anomalyDetector.getId(), 0); + + assertFalse(rcfModelId.isEmpty()); + } + + @Test + public void getThresholdModelId_returnNonEmptyString() { + String thresholdModelId = SingleStreamModelIdMapper.getThresholdModelId(anomalyDetector.getId()); + + assertFalse(thresholdModelId.isEmpty()); + } + + private Entry pair(int size, int value) { + return new SimpleImmutableEntry<>(size, value); + } + + @Test + public void maintenance_returnExpectedToListener_forRcfModel() { + String successModelId = "testSuccessModelId"; + String failModelId = "testFailModelId"; + double[] point = new double[0]; + ThresholdedRandomCutForest forest = mock(ThresholdedRandomCutForest.class); + ThresholdedRandomCutForest failForest = mock(ThresholdedRandomCutForest.class); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(forest)); + return null; + }).when(checkpointDao).getTRCFModel(eq(successModelId), any(ActionListener.class)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(failForest)); + return null; + }).when(checkpointDao).getTRCFModel(eq(failModelId), any(ActionListener.class)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(checkpointDao).putTRCFCheckpoint(eq(successModelId), eq(forest), any(ActionListener.class)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException()); + return null; + }).when(checkpointDao).putTRCFCheckpoint(eq(failModelId), eq(failForest), any(ActionListener.class)); + when(clock.instant()).thenReturn(Instant.EPOCH); + ActionListener scoreListener = mock(ActionListener.class); + modelManager.getTRcfResult(detectorId, successModelId, point, scoreListener); + modelManager.getTRcfResult(detectorId, failModelId, point, scoreListener); + + ActionListener listener = mock(ActionListener.class); + modelManager.maintenance(listener); + + verify(listener).onResponse(eq(null)); + verify(checkpointDao, times(1)).putTRCFCheckpoint(eq(successModelId), eq(forest), any(ActionListener.class)); + verify(checkpointDao, times(1)).putTRCFCheckpoint(eq(failModelId), eq(failForest), any(ActionListener.class)); + } + + @Test + public void maintenance_returnExpectedToListener_forThresholdModel() { + String successModelId = "testSuccessModelId"; + String failModelId = "testFailModelId"; + double score = 1.; + HybridThresholdingModel failThresholdModel = mock(HybridThresholdingModel.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(hybridThresholdingModel)); + return null; + }).when(checkpointDao).getThresholdModel(eq(successModelId), any(ActionListener.class)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(failThresholdModel)); + return null; + }).when(checkpointDao).getThresholdModel(eq(failModelId), any(ActionListener.class)); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(checkpointDao).putThresholdCheckpoint(eq(successModelId), eq(hybridThresholdingModel), any(ActionListener.class)); 
+ doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException()); + return null; + }).when(checkpointDao).putThresholdCheckpoint(eq(failModelId), eq(failThresholdModel), any(ActionListener.class)); + when(clock.instant()).thenReturn(Instant.EPOCH); + ActionListener scoreListener = mock(ActionListener.class); + modelManager.getThresholdingResult(detectorId, successModelId, score, scoreListener); + modelManager.getThresholdingResult(detectorId, failModelId, score, scoreListener); + + ActionListener listener = mock(ActionListener.class); + modelManager.maintenance(listener); + + verify(listener).onResponse(eq(null)); + verify(checkpointDao, times(1)).putThresholdCheckpoint(eq(successModelId), eq(hybridThresholdingModel), any(ActionListener.class)); + } + + @Test + public void maintenance_returnExpectedToListener_stopModel() { + double[] point = new double[0]; + ThresholdedRandomCutForest forest = mock(ThresholdedRandomCutForest.class); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(forest)); + return null; + }).when(checkpointDao).getTRCFModel(eq(rcfModelId), any(ActionListener.class)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(checkpointDao).putTRCFCheckpoint(eq(rcfModelId), eq(forest), any(ActionListener.class)); + when(clock.instant()).thenReturn(Instant.EPOCH, Instant.EPOCH, Instant.EPOCH.plus(modelTtl.plusSeconds(1))); + ActionListener scoreListener = mock(ActionListener.class); + modelManager.getTRcfResult(detectorId, rcfModelId, point, scoreListener); + + ActionListener listener = mock(ActionListener.class); + modelManager.maintenance(listener); + verify(listener).onResponse(eq(null)); + + modelManager.getTRcfResult(detectorId, rcfModelId, point, scoreListener); + verify(checkpointDao, times(2)).getTRCFModel(eq(rcfModelId), any(ActionListener.class)); + } + + @Test + public void maintenance_returnExpectedToListener_doNothing() { + double[] point = new double[0]; + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(rcf)); + return null; + }).when(checkpointDao).getTRCFModel(eq(rcfModelId), any(ActionListener.class)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(checkpointDao).putTRCFCheckpoint(eq(rcfModelId), eq(rcf), any(ActionListener.class)); + when(clock.instant()).thenReturn(Instant.MIN); + ActionListener scoreListener = mock(ActionListener.class); + modelManager.getTRcfResult(detectorId, rcfModelId, point, scoreListener); + ActionListener listener = mock(ActionListener.class); + modelManager.maintenance(listener); + verify(listener).onResponse(eq(null)); + + listener = mock(ActionListener.class); + modelManager.maintenance(listener); + verify(listener).onResponse(eq(null)); + + modelManager.getTRcfResult(detectorId, rcfModelId, point, scoreListener); + verify(checkpointDao, times(1)).getTRCFModel(eq(rcfModelId), any(ActionListener.class)); + } + + @Test + public void getPreviewResults_returnNoAnomalies_forNoAnomalies() { + int numPoints = 1000; + double[][] points = Stream.generate(() -> new double[] { 0 }).limit(numPoints).toArray(double[][]::new); + + List results = modelManager.getPreviewResults(points, shingleSize); + + assertEquals(numPoints, results.size()); + 
assertTrue(results.stream().noneMatch(r -> r.getGrade() > 0)); + } + + @Test + public void getPreviewResults_returnAnomalies_forLastAnomaly() { + int numPoints = 1000; + double[][] points = Stream.generate(() -> new double[] { 0 }).limit(numPoints).toArray(double[][]::new); + points[points.length - 1] = new double[] { 1. }; + + List results = modelManager.getPreviewResults(points, shingleSize); + + assertEquals(numPoints, results.size()); + assertTrue(results.stream().limit(numPoints - 1).noneMatch(r -> r.getGrade() > 0)); + assertTrue(results.get(numPoints - 1).getGrade() > 0); + } + + @Test(expected = IllegalArgumentException.class) + public void getPreviewResults_throwIllegalArgument_forInvalidInput() { + modelManager.getPreviewResults(new double[0][0], shingleSize); + } + + @Test + public void processEmptyCheckpoint() { + ModelState modelState = modelManager.processEntityCheckpoint(Optional.empty(), null, "", "", shingleSize); + assertEquals(Instant.MIN, modelState.getLastCheckpointTime()); + } + + @Test + public void processNonEmptyCheckpoint() { + String modelId = "abc"; + String detectorId = "123"; + EntityModel model = MLUtil.createNonEmptyModel(modelId); + Instant checkpointTime = Instant.ofEpochMilli(1000); + ModelState modelState = modelManager + .processEntityCheckpoint( + Optional.of(new SimpleImmutableEntry<>(model, checkpointTime)), + null, + modelId, + detectorId, + shingleSize + ); + assertEquals(checkpointTime, modelState.getLastCheckpointTime()); + assertEquals(model.getSamples().size(), modelState.getModel().getSamples().size()); + assertEquals(now, modelState.getLastUsedTime()); + } + + @Test + public void getNullState() { + assertEquals(new ThresholdingResult(0, 0, 0), modelManager.getAnomalyResultForEntity(new double[] {}, null, "", null, shingleSize)); + } + + @Test + public void getEmptyStateFullSamples() { + SearchFeatureDao searchFeatureDao = mock(SearchFeatureDao.class); + + LinearUniformImputer interpolator = new LinearUniformImputer(true); + + NodeStateManager stateManager = mock(NodeStateManager.class); + featureManager = new FeatureManager( + searchFeatureDao, + interpolator, + clock, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, + AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, + AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, + AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, + AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, + AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME + ); + + CheckpointWriteWorker checkpointWriteQueue = mock(CheckpointWriteWorker.class); + + entityColdStarter = new EntityColdStarter( + clock, + threadPool, + stateManager, + AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + AnomalyDetectorSettings.NUM_TREES, + AnomalyDetectorSettings.TIME_DECAY, + numMinSamples, + AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, + AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, + interpolator, + searchFeatureDao, + AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + featureManager, + settings, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + ); + + modelManager = spy( + new ModelManager( + checkpointDao, + clock, + numTrees, + numSamples, + rcfTimeDecay, + numMinSamples, + thresholdMinPvalue, + minPreviewSize, + modelTtl, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, + 
entityColdStarter, + featureManager, + memoryTracker, + settings, + clusterService + ) + ); + + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(false).sampleSize(numMinSamples).build()); + EntityModel model = state.getModel(); + assertTrue(!model.getTrcf().isPresent()); + ThresholdingResult result = modelManager.getAnomalyResultForEntity(new double[] { -1 }, state, "", null, shingleSize); + // model outputs scores + assertTrue(result.getRcfScore() != 0); + // added the sample to score since our model is empty + assertEquals(0, model.getSamples().size()); + } + + @Test + public void getAnomalyResultForEntityNoModel() { + ModelState modelState = new ModelState<>(null, modelId, detectorId, ModelType.ENTITY.getName(), clock, 0); + ThresholdingResult result = modelManager + .getAnomalyResultForEntity( + new double[] { -1 }, + modelState, + modelId, + Entity.createSingleAttributeEntity("field", "val"), + shingleSize + ); + // model outputs scores + assertEquals(new ThresholdingResult(0, 0, 0), result); + // added the sample to score since our model is empty + assertEquals(1, modelState.getModel().getSamples().size()); + } + + @Test + public void getEmptyStateNotFullSamples() { + ModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(false).sampleSize(numMinSamples - 1).build()); + assertEquals( + new ThresholdingResult(0, 0, 0), + modelManager.getAnomalyResultForEntity(new double[] { -1 }, state, "", null, shingleSize) + ); + assertEquals(numMinSamples, state.getModel().getSamples().size()); + } + + @Test + public void scoreSamples() { + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + modelManager.getAnomalyResultForEntity(new double[] { -1 }, state, "", null, shingleSize); + assertEquals(0, state.getModel().getSamples().size()); + assertEquals(now, state.getLastUsedTime()); + } + + public void getAnomalyResultForEntity_withTrcf() { + AnomalyDescriptor anomalyDescriptor = new AnomalyDescriptor(point, 0); + anomalyDescriptor.setRCFScore(2); + anomalyDescriptor.setDataConfidence(1); + anomalyDescriptor.setAnomalyGrade(1); + when(this.trcf.process(this.point, 0)).thenReturn(anomalyDescriptor); + + ThresholdingResult result = modelManager + .getAnomalyResultForEntity(this.point, this.modelState, this.detectorId, null, this.shingleSize); + assertEquals( + new ThresholdingResult( + anomalyDescriptor.getAnomalyGrade(), + anomalyDescriptor.getDataConfidence(), + anomalyDescriptor.getRCFScore() + ), + result + ); + } + + @Test + public void score_with_trcf() { + AnomalyDescriptor anomalyDescriptor = new AnomalyDescriptor(point, 0); + anomalyDescriptor.setRCFScore(2); + anomalyDescriptor.setAnomalyGrade(1); + // input dimension is 5 + anomalyDescriptor.setRelevantAttribution(new double[] { 0, 0, 0, 0, 0 }); + RandomCutForest rcf = mock(RandomCutForest.class); + when(rcf.getShingleSize()).thenReturn(8); + when(rcf.getDimensions()).thenReturn(40); + when(this.trcf.getForest()).thenReturn(rcf); + when(this.trcf.process(this.point, 0)).thenReturn(anomalyDescriptor); + when(this.entityModel.getSamples()).thenReturn(new ArrayDeque<>(Arrays.asList(this.point))); + + ThresholdingResult result = modelManager.score(this.point, this.detectorId, this.modelState); + assertEquals( + new ThresholdingResult( + anomalyDescriptor.getAnomalyGrade(), + anomalyDescriptor.getDataConfidence(), + anomalyDescriptor.getRCFScore(), + 0, + 0, + anomalyDescriptor.getRelevantAttribution(), + 
null, + null, + null, + 0, + numTrees + ), + result + ); + } + + @Test(expected = IllegalArgumentException.class) + public void score_throw() { + AnomalyDescriptor anomalyDescriptor = new AnomalyDescriptor(point, 0); + anomalyDescriptor.setRCFScore(2); + anomalyDescriptor.setAnomalyGrade(1); + // input dimension is 5 + anomalyDescriptor.setRelevantAttribution(new double[] { 0, 0, 0, 0, 0 }); + RandomCutForest rcf = mock(RandomCutForest.class); + when(rcf.getShingleSize()).thenReturn(8); + when(rcf.getDimensions()).thenReturn(40); + when(this.trcf.getForest()).thenReturn(rcf); + doThrow(new IllegalArgumentException()).when(trcf).process(any(), anyLong()); + when(this.entityModel.getSamples()).thenReturn(new ArrayDeque<>(Arrays.asList(this.point))); + modelManager.score(this.point, this.detectorId, this.modelState); + } +} diff --git a/src/test/java/org/opensearch/ad/ml/SingleStreamModelIdMapperTests.java-e b/src/test/java/org/opensearch/ad/ml/SingleStreamModelIdMapperTests.java-e new file mode 100644 index 000000000..59a0d02da --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/SingleStreamModelIdMapperTests.java-e @@ -0,0 +1,24 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ml; + +import org.opensearch.test.OpenSearchTestCase; + +public class SingleStreamModelIdMapperTests extends OpenSearchTestCase { + public void testGetThresholdModelIdFromRCFModelId() { + assertEquals( + "Y62IGnwBFHAk-4HQQeoo_model_threshold", + SingleStreamModelIdMapper.getThresholdModelIdFromRCFModelId("Y62IGnwBFHAk-4HQQeoo_model_rcf_1") + ); + } + +} diff --git a/src/test/java/org/opensearch/ad/ml/ThresholdingResultTests.java-e b/src/test/java/org/opensearch/ad/ml/ThresholdingResultTests.java-e new file mode 100644 index 000000000..492bbec45 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ml/ThresholdingResultTests.java-e @@ -0,0 +1,71 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ml; + +import static org.junit.Assert.assertEquals; + +import junitparams.JUnitParamsRunner; +import junitparams.Parameters; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.opensearch.ad.model.AnomalyResult; + +@RunWith(JUnitParamsRunner.class) +public class ThresholdingResultTests { + + private double grade = 1.; + private double confidence = 0.5; + double score = 1.; + + private ThresholdingResult thresholdingResult = new ThresholdingResult(grade, confidence, score); + + @Test + public void getters_returnExcepted() { + assertEquals(grade, thresholdingResult.getGrade(), 1e-8); + assertEquals(confidence, thresholdingResult.getConfidence(), 1e-8); + } + + private Object[] equalsData() { + return new Object[] { + new Object[] { thresholdingResult, thresholdingResult, true }, + new Object[] { thresholdingResult, null, false }, + new Object[] { thresholdingResult, AnomalyResult.getDummyResult(), false }, + new Object[] { thresholdingResult, null, false }, + new Object[] { thresholdingResult, thresholdingResult, true }, + new Object[] { thresholdingResult, 1, false }, + new Object[] { thresholdingResult, new ThresholdingResult(grade, confidence, score), true }, + new Object[] { thresholdingResult, new ThresholdingResult(grade + 1, confidence, score), false }, + new Object[] { thresholdingResult, new ThresholdingResult(grade, confidence + 1, score), false }, + new Object[] { thresholdingResult, new ThresholdingResult(grade + 1, confidence + 1, score), false }, }; + } + + @Test + @Parameters(method = "equalsData") + public void equals_returnExpected(ThresholdingResult result, Object other, boolean expected) { + assertEquals(expected, result.equals(other)); + } + + private Object[] hashCodeData() { + return new Object[] { + new Object[] { thresholdingResult, new ThresholdingResult(grade, confidence, score), true }, + new Object[] { thresholdingResult, new ThresholdingResult(grade + 1, confidence, score), false }, + new Object[] { thresholdingResult, new ThresholdingResult(grade, confidence + 1, score), false }, + new Object[] { thresholdingResult, new ThresholdingResult(grade + 1, confidence + 1, score), false }, }; + } + + @Test + @Parameters(method = "hashCodeData") + public void hashCode_returnExpected(ThresholdingResult result, ThresholdingResult other, boolean expected) { + assertEquals(expected, result.hashCode() == other.hashCode()); + } +} diff --git a/src/test/java/org/opensearch/ad/mock/model/MockSimpleLog.java b/src/test/java/org/opensearch/ad/mock/model/MockSimpleLog.java index a6716b719..e97b77fce 100644 --- a/src/test/java/org/opensearch/ad/mock/model/MockSimpleLog.java +++ b/src/test/java/org/opensearch/ad/mock/model/MockSimpleLog.java @@ -14,8 +14,8 @@ import java.io.IOException; import java.time.Instant; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/test/java/org/opensearch/ad/mock/model/MockSimpleLog.java-e b/src/test/java/org/opensearch/ad/mock/model/MockSimpleLog.java-e new file mode 100644 index 000000000..e97b77fce --- /dev/null +++ b/src/test/java/org/opensearch/ad/mock/model/MockSimpleLog.java-e @@ -0,0 +1,148 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * 
this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.mock.model; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +public class MockSimpleLog implements ToXContentObject, Writeable { + + public static final String TIME_FIELD = "timestamp"; + public static final String VALUE_FIELD = "value"; + public static final String IP_FIELD = "ip"; + public static final String CATEGORY_FIELD = "category"; + public static final String IS_ERROR_FIELD = "is_error"; + public static final String MESSAGE_FIELD = "message"; + + public static final String INDEX_MAPPING = "{\"mappings\":{\"properties\":{" + + "\"" + + TIME_FIELD + + "\":{\"type\":\"date\",\"format\":\"strict_date_time||epoch_millis\"}," + + "\"" + + VALUE_FIELD + + "\":{\"type\":\"double\"}," + + "\"" + + IP_FIELD + + "\":{\"type\":\"ip\"}," + + "\"" + + CATEGORY_FIELD + + "\":{\"type\":\"keyword\"}," + + "\"" + + IS_ERROR_FIELD + + "\":{\"type\":\"boolean\"}," + + "\"" + + MESSAGE_FIELD + + "\":{\"type\":\"text\"}}}}"; + + private Instant timestamp; + private Double value; + private String ip; + private String category; + private Boolean isError; + private String message; + + public MockSimpleLog(Instant timestamp, Double value, String ip, String category, Boolean isError, String message) { + this.timestamp = timestamp; + this.value = value; + this.ip = ip; + this.category = category; + this.isError = isError; + this.message = message; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalInstant(timestamp); + out.writeOptionalDouble(value); + out.writeOptionalString(ip); + out.writeOptionalString(category); + out.writeOptionalBoolean(isError); + out.writeOptionalString(message); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + if (timestamp != null) { + xContentBuilder.field(TIME_FIELD, timestamp.toEpochMilli()); + } + if (value != null) { + xContentBuilder.field(VALUE_FIELD, value); + } + if (ip != null) { + xContentBuilder.field(IP_FIELD, ip); + } + if (category != null) { + xContentBuilder.field(CATEGORY_FIELD, category); + } + if (isError != null) { + xContentBuilder.field(IS_ERROR_FIELD, isError); + } + if (message != null) { + xContentBuilder.field(MESSAGE_FIELD, message); + } + return xContentBuilder.endObject(); + } + + public Instant getTimestamp() { + return timestamp; + } + + public void setTimestamp(Instant timestamp) { + this.timestamp = timestamp; + } + + public Double getValue() { + return value; + } + + public void setValue(Double value) { + this.value = value; + } + + public String getIp() { + return ip; + } + + public void setIp(String ip) { + this.ip = ip; + } + + public String getCategory() { + return category; + } + + public void setCategory(String category) { + this.category = category; + } + + public Boolean getError() { + return isError; + } + + public void setError(Boolean error) { + isError = error; + } + + public String getMessage() { + return message; + } + + public void setMessage(String message) { + this.message = message; + } +} diff --git 
a/src/test/java/org/opensearch/ad/mock/plugin/MockReindexPlugin.java-e b/src/test/java/org/opensearch/ad/mock/plugin/MockReindexPlugin.java-e new file mode 100644 index 000000000..29db051e0 --- /dev/null +++ b/src/test/java/org/opensearch/ad/mock/plugin/MockReindexPlugin.java-e @@ -0,0 +1,158 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.mock.plugin; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionResponse; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkRequestBuilder; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.mock.transport.MockAnomalyDetectorJobAction; +import org.opensearch.ad.mock.transport.MockAnomalyDetectorJobTransportActionWithUser; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.BulkByScrollTask; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.index.reindex.UpdateByQueryAction; +import org.opensearch.index.reindex.UpdateByQueryRequest; +import org.opensearch.plugins.ActionPlugin; +import org.opensearch.plugins.Plugin; +import org.opensearch.search.SearchHit; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableList; + +public class MockReindexPlugin extends Plugin implements ActionPlugin { + + @Override + public List> getActions() { + return Arrays + .asList( + new ActionHandler<>(UpdateByQueryAction.INSTANCE, MockTransportUpdateByQueryAction.class), + new ActionHandler<>(DeleteByQueryAction.INSTANCE, MockTransportDeleteByQueryAction.class), + new ActionHandler<>(MockAnomalyDetectorJobAction.INSTANCE, MockAnomalyDetectorJobTransportActionWithUser.class) + ); + } + + public static class MockTransportUpdateByQueryAction extends HandledTransportAction { + + @Inject + public MockTransportUpdateByQueryAction(ActionFilters actionFilters, TransportService transportService) { + super(UpdateByQueryAction.NAME, transportService, actionFilters, UpdateByQueryRequest::new); + } + + @Override + protected void doExecute(Task task, UpdateByQueryRequest request, ActionListener listener) { + BulkByScrollResponse response = null; + try { + XContentParser parser = TestHelpers + .parser( + "{\"slice_id\":1,\"total\":2,\"updated\":3,\"created\":0,\"deleted\":0,\"batches\":6," + + "\"version_conflicts\":0,\"noops\":0,\"retries\":{\"bulk\":0,\"search\":10}," + + 
"\"throttled_millis\":0,\"requests_per_second\":13.0,\"canceled\":\"reasonCancelled\"," + + "\"throttled_until_millis\":14}" + ); + parser.nextToken(); + response = new BulkByScrollResponse( + TimeValue.timeValueMillis(10), + BulkByScrollTask.Status.innerFromXContent(parser), + ImmutableList.of(), + ImmutableList.of(), + false + ); + } catch (IOException exception) { + logger.error(exception); + } + listener.onResponse(response); + } + } + + public static class MockTransportDeleteByQueryAction extends HandledTransportAction { + + private Client client; + + @Inject + public MockTransportDeleteByQueryAction(ActionFilters actionFilters, TransportService transportService, Client client) { + super(DeleteByQueryAction.NAME, transportService, actionFilters, DeleteByQueryRequest::new); + this.client = client; + } + + @Override + protected void doExecute(Task task, DeleteByQueryRequest request, ActionListener listener) { + try { + SearchRequest searchRequest = request.getSearchRequest(); + client.search(searchRequest, ActionListener.wrap(r -> { + long totalHits = r.getHits().getTotalHits().value; + Iterator iterator = r.getHits().iterator(); + BulkRequestBuilder bulkRequestBuilder = client.prepareBulk(); + while (iterator.hasNext()) { + String id = iterator.next().getId(); + DeleteRequest deleteRequest = new DeleteRequest(ADCommonName.DETECTION_STATE_INDEX, id); + bulkRequestBuilder.add(deleteRequest); + } + BulkRequest bulkRequest = bulkRequestBuilder.request().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client + .execute( + BulkAction.INSTANCE, + bulkRequest, + ActionListener + .wrap( + res -> { listener.onResponse(mockBulkByScrollResponse(totalHits)); }, + ex -> { listener.onFailure(ex); } + ) + ); + + }, e -> { listener.onFailure(e); })); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private BulkByScrollResponse mockBulkByScrollResponse(long totalHits) throws IOException { + XContentParser parser = TestHelpers + .parser( + "{\"slice_id\":1,\"total\":2,\"updated\":0,\"created\":0,\"deleted\":" + + totalHits + + ",\"batches\":6,\"version_conflicts\":0,\"noops\":0,\"retries\":{\"bulk\":0," + + "\"search\":10},\"throttled_millis\":0,\"requests_per_second\":13.0,\"canceled\":" + + "\"reasonCancelled\",\"throttled_until_millis\":14}" + ); + parser.nextToken(); + BulkByScrollResponse response = new BulkByScrollResponse( + TimeValue.timeValueMillis(10), + BulkByScrollTask.Status.innerFromXContent(parser), + ImmutableList.of(), + ImmutableList.of(), + false + ); + return response; + } + } +} diff --git a/src/test/java/org/opensearch/ad/mock/transport/MockADCancelTaskNodeRequest_1_0.java b/src/test/java/org/opensearch/ad/mock/transport/MockADCancelTaskNodeRequest_1_0.java index fff9aa524..266ec5bd0 100644 --- a/src/test/java/org/opensearch/ad/mock/transport/MockADCancelTaskNodeRequest_1_0.java +++ b/src/test/java/org/opensearch/ad/mock/transport/MockADCancelTaskNodeRequest_1_0.java @@ -13,8 +13,8 @@ import java.io.IOException; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.transport.TransportRequest; public class MockADCancelTaskNodeRequest_1_0 extends TransportRequest { diff --git a/src/test/java/org/opensearch/ad/mock/transport/MockADCancelTaskNodeRequest_1_0.java-e b/src/test/java/org/opensearch/ad/mock/transport/MockADCancelTaskNodeRequest_1_0.java-e new file mode 100644 
index 000000000..266ec5bd0 --- /dev/null +++ b/src/test/java/org/opensearch/ad/mock/transport/MockADCancelTaskNodeRequest_1_0.java-e @@ -0,0 +1,50 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.mock.transport; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +public class MockADCancelTaskNodeRequest_1_0 extends TransportRequest { + private String detectorId; + private String userName; + + public MockADCancelTaskNodeRequest_1_0(StreamInput in) throws IOException { + super(in); + this.detectorId = in.readOptionalString(); + this.userName = in.readOptionalString(); + } + + public MockADCancelTaskNodeRequest_1_0(String detectorId, String userName) { + this.detectorId = detectorId; + this.userName = userName; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalString(detectorId); + out.writeOptionalString(userName); + } + + public String getId() { + return detectorId; + } + + public String getUserName() { + return userName; + } + +} diff --git a/src/test/java/org/opensearch/ad/mock/transport/MockADTaskAction_1_0.java-e b/src/test/java/org/opensearch/ad/mock/transport/MockADTaskAction_1_0.java-e new file mode 100644 index 000000000..b52eb5339 --- /dev/null +++ b/src/test/java/org/opensearch/ad/mock/transport/MockADTaskAction_1_0.java-e @@ -0,0 +1,17 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.mock.transport; + +public enum MockADTaskAction_1_0 { + START, + STOP +} diff --git a/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobAction.java-e b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobAction.java-e new file mode 100644 index 000000000..327e3bf51 --- /dev/null +++ b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobAction.java-e @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.mock.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.transport.AnomalyDetectorJobResponse; + +public class MockAnomalyDetectorJobAction extends ActionType<AnomalyDetectorJobResponse> { + // External action which is used for public facing REST APIs.
+ public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/mockjobmanagement"; + public static final MockAnomalyDetectorJobAction INSTANCE = new MockAnomalyDetectorJobAction(); + + private MockAnomalyDetectorJobAction() { + super(NAME, AnomalyDetectorJobResponse::new); + } + +} diff --git a/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java-e b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java-e new file mode 100644 index 000000000..15d37c89d --- /dev/null +++ b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java-e @@ -0,0 +1,155 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.mock.transport; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; +import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.ad.ExecuteADResultResponseRecorder; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.AnomalyDetectorJobRequest; +import org.opensearch.ad.transport.AnomalyDetectorJobResponse; +import org.opensearch.ad.transport.AnomalyDetectorJobTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +public class MockAnomalyDetectorJobTransportActionWithUser extends + HandledTransportAction { + private final Logger logger = LogManager.getLogger(AnomalyDetectorJobTransportAction.class); + + private final Client client; + private final ClusterService clusterService; + private final Settings settings; + private final ADIndexManagement anomalyDetectionIndices; + private final NamedXContentRegistry xContentRegistry; + private volatile Boolean filterByEnabled; + private ThreadContext.StoredContext context; + private final ADTaskManager adTaskManager; + private final TransportService transportService; + private final ExecuteADResultResponseRecorder recorder; + + @Inject + public MockAnomalyDetectorJobTransportActionWithUser( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService, + Settings settings, + ADIndexManagement anomalyDetectionIndices, + NamedXContentRegistry xContentRegistry, + 
ADTaskManager adTaskManager, + ExecuteADResultResponseRecorder recorder + ) { + super(MockAnomalyDetectorJobAction.NAME, transportService, actionFilters, AnomalyDetectorJobRequest::new); + this.transportService = transportService; + this.client = client; + this.clusterService = clusterService; + this.settings = settings; + this.anomalyDetectionIndices = anomalyDetectionIndices; + this.xContentRegistry = xContentRegistry; + this.adTaskManager = adTaskManager; + filterByEnabled = FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + + ThreadContext threadContext = new ThreadContext(settings); + context = threadContext.stashContext(); + this.recorder = recorder; + } + + @Override + protected void doExecute(Task task, AnomalyDetectorJobRequest request, ActionListener listener) { + String detectorId = request.getDetectorID(); + DateRange detectionDateRange = request.getDetectionDateRange(); + boolean historical = request.isHistorical(); + long seqNo = request.getSeqNo(); + long primaryTerm = request.getPrimaryTerm(); + String rawPath = request.getRawPath(); + TimeValue requestTimeout = REQUEST_TIMEOUT.get(settings); + String userStr = "user_name|backendrole1,backendrole2|roles1,role2"; + // By the time request reaches here, the user permissions are validated by Security plugin. + User user = User.parse(userStr); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute( + user, + detectorId, + filterByEnabled, + listener, + (anomalyDetector) -> executeDetector( + listener, + detectorId, + seqNo, + primaryTerm, + rawPath, + requestTimeout, + user, + detectionDateRange, + historical + ), + client, + clusterService, + xContentRegistry + ); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + private void executeDetector( + ActionListener listener, + String detectorId, + long seqNo, + long primaryTerm, + String rawPath, + TimeValue requestTimeout, + User user, + DateRange detectionDateRange, + boolean historical + ) { + IndexAnomalyDetectorJobActionHandler handler = new IndexAnomalyDetectorJobActionHandler( + client, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + requestTimeout, + xContentRegistry, + transportService, + adTaskManager, + recorder + ); + if (rawPath.endsWith(RestHandlerUtils.START_JOB)) { + adTaskManager.startDetector(detectorId, detectionDateRange, handler, user, transportService, context, listener); + } else if (rawPath.endsWith(RestHandlerUtils.STOP_JOB)) { + // Stop detector + adTaskManager.stopDetector(detectorId, historical, handler, user, transportService, listener); + } + } +} diff --git a/src/test/java/org/opensearch/ad/mock/transport/MockForwardADTaskRequest_1_0.java b/src/test/java/org/opensearch/ad/mock/transport/MockForwardADTaskRequest_1_0.java index 610fbb1fd..8b4f5e0d3 100644 --- a/src/test/java/org/opensearch/ad/mock/transport/MockForwardADTaskRequest_1_0.java +++ b/src/test/java/org/opensearch/ad/mock/transport/MockForwardADTaskRequest_1_0.java @@ -19,9 +19,9 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import 
org.opensearch.core.common.io.stream.StreamOutput; public class MockForwardADTaskRequest_1_0 extends ActionRequest { private AnomalyDetector detector; diff --git a/src/test/java/org/opensearch/ad/mock/transport/MockForwardADTaskRequest_1_0.java-e b/src/test/java/org/opensearch/ad/mock/transport/MockForwardADTaskRequest_1_0.java-e new file mode 100644 index 000000000..5cb235d58 --- /dev/null +++ b/src/test/java/org/opensearch/ad/mock/transport/MockForwardADTaskRequest_1_0.java-e @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.mock.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.commons.authuser.User; + +public class MockForwardADTaskRequest_1_0 extends ActionRequest { + private AnomalyDetector detector; + private User user; + private MockADTaskAction_1_0 adTaskAction; + + public MockForwardADTaskRequest_1_0(AnomalyDetector detector, User user, MockADTaskAction_1_0 adTaskAction) { + this.detector = detector; + this.user = user; + this.adTaskAction = adTaskAction; + } + + public MockForwardADTaskRequest_1_0(StreamInput in) throws IOException { + super(in); + this.detector = new AnomalyDetector(in); + if (in.readBoolean()) { + this.user = new User(in); + } + this.adTaskAction = in.readEnum(MockADTaskAction_1_0.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + detector.writeTo(out); + if (user != null) { + out.writeBoolean(true); + user.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeEnum(adTaskAction); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (detector == null) { + validationException = addValidationError(ADCommonMessages.DETECTOR_MISSING, validationException); + } else if (detector.getId() == null) { + validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + } + if (adTaskAction == null) { + validationException = addValidationError(ADCommonMessages.AD_TASK_ACTION_MISSING, validationException); + } + return validationException; + } + + public AnomalyDetector getDetector() { + return detector; + } + + public User getUser() { + return user; + } + + public MockADTaskAction_1_0 getAdTaskAction() { + return adTaskAction; + } +} diff --git a/src/test/java/org/opensearch/ad/model/ADEntityTaskProfileTests.java b/src/test/java/org/opensearch/ad/model/ADEntityTaskProfileTests.java index daf8812f4..27456589a 100644 --- a/src/test/java/org/opensearch/ad/model/ADEntityTaskProfileTests.java +++ b/src/test/java/org/opensearch/ad/model/ADEntityTaskProfileTests.java @@ -9,22 +9,22 @@ import java.util.Collection; import java.util.TreeMap; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.common.io.stream.BytesStreamOutput; -import 
org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.plugins.Plugin; import org.opensearch.test.InternalSettingsPlugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.model.Entity; public class ADEntityTaskProfileTests extends OpenSearchSingleNodeTestCase { @Override protected Collection> getPlugins() { - return pluginList(InternalSettingsPlugin.class, AnomalyDetectorPlugin.class); + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); } @Override diff --git a/src/test/java/org/opensearch/ad/model/ADEntityTaskProfileTests.java-e b/src/test/java/org/opensearch/ad/model/ADEntityTaskProfileTests.java-e new file mode 100644 index 000000000..27456589a --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/ADEntityTaskProfileTests.java-e @@ -0,0 +1,125 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.model; + +import java.io.IOException; +import java.util.Collection; +import java.util.TreeMap; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.InternalSettingsPlugin; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.model.Entity; + +public class ADEntityTaskProfileTests extends OpenSearchSingleNodeTestCase { + + @Override + protected Collection> getPlugins() { + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); + } + + @Override + protected NamedWriteableRegistry writableRegistry() { + return getInstanceFromNode(NamedWriteableRegistry.class); + } + + private ADEntityTaskProfile createADEntityTaskProfile() { + Entity entity = createEntityAndAttributes(); + return new ADEntityTaskProfile(1, 23L, false, 1, 2L, "1234", entity, "4321", ADTaskType.HISTORICAL_HC_ENTITY.name()); + } + + private Entity createEntityAndAttributes() { + TreeMap attributes = new TreeMap<>(); + String name1 = "host"; + String val1 = "server_2"; + String name2 = "service"; + String val2 = "app_4"; + attributes.put(name1, val1); + attributes.put(name2, val2); + return Entity.createEntityFromOrderedMap(attributes); + } + + public void testADEntityTaskProfileSerialization() throws IOException { + ADEntityTaskProfile entityTask = createADEntityTaskProfile(); + BytesStreamOutput output = new BytesStreamOutput(); + entityTask.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + ADEntityTaskProfile parsedEntityTask = new ADEntityTaskProfile(input); + assertEquals(entityTask, parsedEntityTask); + } + + public void testParseADEntityTaskProfile() throws IOException { + ADEntityTaskProfile entityTask = createADEntityTaskProfile(); + String adEntityTaskProfileString = 
TestHelpers + .xContentBuilderToString(entityTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + ADEntityTaskProfile parsedEntityTask = ADEntityTaskProfile.parse(TestHelpers.parser(adEntityTaskProfileString)); + assertEquals(entityTask, parsedEntityTask); + } + + public void testParseADEntityTaskProfileWithNullEntity() throws IOException { + ADEntityTaskProfile entityTask = new ADEntityTaskProfile( + 1, + 23L, + false, + 1, + 2L, + "1234", + null, + "4321", + ADTaskType.HISTORICAL_HC_ENTITY.name() + ); + assertEquals(Integer.valueOf(1), entityTask.getShingleSize()); + assertEquals(23L, (long) entityTask.getRcfTotalUpdates()); + assertNull(entityTask.getEntity()); + String adEntityTaskProfileString = TestHelpers + .xContentBuilderToString(entityTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + ADEntityTaskProfile parsedEntityTask = ADEntityTaskProfile.parse(TestHelpers.parser(adEntityTaskProfileString)); + assertEquals(entityTask, parsedEntityTask); + } + + public void testADEntityTaskProfileEqual() { + ADEntityTaskProfile entityTaskOne = createADEntityTaskProfile(); + ADEntityTaskProfile entityTaskTwo = createADEntityTaskProfile(); + ADEntityTaskProfile entityTaskThree = new ADEntityTaskProfile( + null, + null, + false, + 1, + null, + "1234", + null, + "4321", + ADTaskType.HISTORICAL_HC_ENTITY.name() + ); + assertTrue(entityTaskOne.equals(entityTaskTwo)); + assertFalse(entityTaskOne.equals(entityTaskThree)); + } + + public void testParseADEntityTaskProfileWithMultipleNullFields() throws IOException { + Entity entity = createEntityAndAttributes(); + ADEntityTaskProfile entityTask = new ADEntityTaskProfile( + null, + null, + false, + 1, + null, + "1234", + entity, + "4321", + ADTaskType.HISTORICAL_HC_ENTITY.name() + ); + String adEntityTaskProfileString = TestHelpers + .xContentBuilderToString(entityTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + ADEntityTaskProfile parsedEntityTask = ADEntityTaskProfile.parse(TestHelpers.parser(adEntityTaskProfileString)); + assertEquals(entityTask, parsedEntityTask); + } +} diff --git a/src/test/java/org/opensearch/ad/model/ADTaskTests.java b/src/test/java/org/opensearch/ad/model/ADTaskTests.java index 546a09f52..1cd2e6cc8 100644 --- a/src/test/java/org/opensearch/ad/model/ADTaskTests.java +++ b/src/test/java/org/opensearch/ad/model/ADTaskTests.java @@ -16,21 +16,21 @@ import java.time.temporal.ChronoUnit; import java.util.Collection; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.plugins.Plugin; import org.opensearch.test.InternalSettingsPlugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; public class ADTaskTests extends OpenSearchSingleNodeTestCase { @Override protected Collection> getPlugins() { - return pluginList(InternalSettingsPlugin.class, AnomalyDetectorPlugin.class); + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); } @Override diff --git a/src/test/java/org/opensearch/ad/model/ADTaskTests.java-e 
b/src/test/java/org/opensearch/ad/model/ADTaskTests.java-e new file mode 100644 index 000000000..1cd2e6cc8 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/ADTaskTests.java-e @@ -0,0 +1,94 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collection; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.InternalSettingsPlugin; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +public class ADTaskTests extends OpenSearchSingleNodeTestCase { + + @Override + protected Collection> getPlugins() { + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); + } + + @Override + protected NamedWriteableRegistry writableRegistry() { + return getInstanceFromNode(NamedWriteableRegistry.class); + } + + public void testAdTaskSerialization() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(randomAlphaOfLength(5), ADTaskState.STOPPED, Instant.now(), randomAlphaOfLength(5), true); + BytesStreamOutput output = new BytesStreamOutput(); + adTask.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + ADTask parsedADTask = new ADTask(input); + assertEquals("AD task serialization doesn't work", adTask, parsedADTask); + } + + public void testAdTaskSerializationWithNullDetector() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(randomAlphaOfLength(5), ADTaskState.STOPPED, Instant.now(), randomAlphaOfLength(5), false); + BytesStreamOutput output = new BytesStreamOutput(); + adTask.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + ADTask parsedADTask = new ADTask(input); + assertEquals("AD task serialization doesn't work", adTask, parsedADTask); + } + + public void testParseADTask() throws IOException { + ADTask adTask = TestHelpers + .randomAdTask(null, ADTaskState.STOPPED, Instant.now().truncatedTo(ChronoUnit.SECONDS), randomAlphaOfLength(5), true); + String taskId = randomAlphaOfLength(5); + adTask.setTaskId(taskId); + String adTaskString = TestHelpers.xContentBuilderToString(adTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + ADTask parsedADTask = ADTask.parse(TestHelpers.parser(adTaskString), adTask.getTaskId()); + assertEquals("Parsing AD task doesn't work", adTask, parsedADTask); + } + + public void testParseADTaskWithoutTaskId() throws IOException { + String taskId = null; + ADTask adTask = TestHelpers + .randomAdTask(taskId, ADTaskState.STOPPED, Instant.now().truncatedTo(ChronoUnit.SECONDS), randomAlphaOfLength(5), true); + String adTaskString = TestHelpers.xContentBuilderToString(adTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + ADTask 
parsedADTask = ADTask.parse(TestHelpers.parser(adTaskString)); + assertEquals("Parsing AD task doesn't work", adTask, parsedADTask); + } + + public void testParseADTaskWithNullDetector() throws IOException { + String taskId = randomAlphaOfLength(5); + ADTask adTask = TestHelpers + .randomAdTask(taskId, ADTaskState.STOPPED, Instant.now().truncatedTo(ChronoUnit.SECONDS), randomAlphaOfLength(5), false); + String adTaskString = TestHelpers.xContentBuilderToString(adTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + ADTask parsedADTask = ADTask.parse(TestHelpers.parser(adTaskString), taskId); + assertEquals("Parsing AD task doesn't work", adTask, parsedADTask); + } + + public void testParseNullableFields() throws IOException { + ADTask adTask = ADTask.builder().build(); + String adTaskString = TestHelpers.xContentBuilderToString(adTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + ADTask parsedADTask = ADTask.parse(TestHelpers.parser(adTaskString)); + assertEquals("Parsing AD task doesn't work", adTask, parsedADTask); + } + +} diff --git a/src/test/java/org/opensearch/ad/model/AnomalyDetectorExecutionInputTests.java-e b/src/test/java/org/opensearch/ad/model/AnomalyDetectorExecutionInputTests.java-e new file mode 100644 index 000000000..d383aed3d --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/AnomalyDetectorExecutionInputTests.java-e @@ -0,0 +1,64 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Locale; + +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; + +public class AnomalyDetectorExecutionInputTests extends OpenSearchTestCase { + + public void testParseAnomalyDetectorExecutionInput() throws IOException { + AnomalyDetectorExecutionInput detectorExecutionInput = TestHelpers.randomAnomalyDetectorExecutionInput(); + String detectInputString = TestHelpers + .xContentBuilderToString(detectorExecutionInput.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + detectInputString = detectInputString + .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + AnomalyDetectorExecutionInput parsedAnomalyDetectorExecutionInput = AnomalyDetectorExecutionInput + .parse(TestHelpers.parser(detectInputString), detectorExecutionInput.getDetectorId()); + assertEquals("Parsing anomaly detect execution input doesn't work", detectorExecutionInput, parsedAnomalyDetectorExecutionInput); + } + + public void testNullPeriodStart() throws Exception { + TestHelpers + .assertFailWith( + IllegalArgumentException.class, + () -> new AnomalyDetectorExecutionInput(randomAlphaOfLength(5), null, Instant.now(), null) + ); + } + + public void testNullPeriodEnd() throws Exception { + TestHelpers + .assertFailWith( + IllegalArgumentException.class, + () -> new AnomalyDetectorExecutionInput(randomAlphaOfLength(5), Instant.now(), null, null) + ); + } + + public void testWrongPeriod() throws Exception { + TestHelpers + .assertFailWith( + IllegalArgumentException.class, + () -> new AnomalyDetectorExecutionInput( + 
randomAlphaOfLength(5), + Instant.now(), + Instant.now().minus(5, ChronoUnit.MINUTES), + null + ) + ); + } +} diff --git a/src/test/java/org/opensearch/ad/model/AnomalyDetectorJobTests.java b/src/test/java/org/opensearch/ad/model/AnomalyDetectorJobTests.java index bb165e665..75d821507 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyDetectorJobTests.java +++ b/src/test/java/org/opensearch/ad/model/AnomalyDetectorJobTests.java @@ -15,21 +15,21 @@ import java.util.Collection; import java.util.Locale; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.plugins.Plugin; import org.opensearch.test.InternalSettingsPlugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; public class AnomalyDetectorJobTests extends OpenSearchSingleNodeTestCase { @Override protected Collection> getPlugins() { - return pluginList(InternalSettingsPlugin.class, AnomalyDetectorPlugin.class); + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); } @Override diff --git a/src/test/java/org/opensearch/ad/model/AnomalyDetectorJobTests.java-e b/src/test/java/org/opensearch/ad/model/AnomalyDetectorJobTests.java-e new file mode 100644 index 000000000..75d821507 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/AnomalyDetectorJobTests.java-e @@ -0,0 +1,59 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.model; + +import java.io.IOException; +import java.util.Collection; +import java.util.Locale; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.InternalSettingsPlugin; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +public class AnomalyDetectorJobTests extends OpenSearchSingleNodeTestCase { + + @Override + protected Collection> getPlugins() { + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); + } + + @Override + protected NamedWriteableRegistry writableRegistry() { + return getInstanceFromNode(NamedWriteableRegistry.class); + } + + public void testParseAnomalyDetectorJob() throws IOException { + AnomalyDetectorJob anomalyDetectorJob = TestHelpers.randomAnomalyDetectorJob(); + String anomalyDetectorJobString = TestHelpers + .xContentBuilderToString(anomalyDetectorJob.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + anomalyDetectorJobString = anomalyDetectorJobString + .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + + AnomalyDetectorJob parsedAnomalyDetectorJob = AnomalyDetectorJob.parse(TestHelpers.parser(anomalyDetectorJobString)); + assertEquals("Parsing anomaly detect result doesn't work", anomalyDetectorJob, parsedAnomalyDetectorJob); + } + + public void testSerialization() throws IOException { + AnomalyDetectorJob anomalyDetectorJob = TestHelpers.randomAnomalyDetectorJob(); + BytesStreamOutput output = new BytesStreamOutput(); + anomalyDetectorJob.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + AnomalyDetectorJob parsedAnomalyDetectorJob = new AnomalyDetectorJob(input); + assertNotNull(parsedAnomalyDetectorJob); + } +} diff --git a/src/test/java/org/opensearch/ad/model/AnomalyDetectorSerializationTests.java b/src/test/java/org/opensearch/ad/model/AnomalyDetectorSerializationTests.java index 2d9f5baf9..aa32bf495 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyDetectorSerializationTests.java +++ b/src/test/java/org/opensearch/ad/model/AnomalyDetectorSerializationTests.java @@ -15,14 +15,14 @@ import java.time.Instant; import java.util.Collection; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.plugins.Plugin; import org.opensearch.test.InternalSettingsPlugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -30,7 +30,7 @@ public class AnomalyDetectorSerializationTests extends OpenSearchSingleNodeTestCase { @Override protected Collection> getPlugins() { - return 
pluginList(InternalSettingsPlugin.class, AnomalyDetectorPlugin.class); + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); } @Override diff --git a/src/test/java/org/opensearch/ad/model/AnomalyDetectorSerializationTests.java-e b/src/test/java/org/opensearch/ad/model/AnomalyDetectorSerializationTests.java-e new file mode 100644 index 000000000..aa32bf495 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/AnomalyDetectorSerializationTests.java-e @@ -0,0 +1,78 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collection; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.InternalSettingsPlugin; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class AnomalyDetectorSerializationTests extends OpenSearchSingleNodeTestCase { + @Override + protected Collection> getPlugins() { + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); + } + + @Override + protected NamedWriteableRegistry writableRegistry() { + return getInstanceFromNode(NamedWriteableRegistry.class); + } + + public void testDetectorWithUiMetadata() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + BytesStreamOutput output = new BytesStreamOutput(); + detector.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + AnomalyDetector parsedDetector = new AnomalyDetector(input); + assertTrue(parsedDetector.equals(detector)); + } + + public void testDetectorWithoutUiMetadata() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(null, Instant.now()); + BytesStreamOutput output = new BytesStreamOutput(); + detector.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + AnomalyDetector parsedDetector = new AnomalyDetector(input); + assertTrue(parsedDetector.equals(detector)); + } + + public void testHCDetector() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields("testId", ImmutableList.of("category_field")); + BytesStreamOutput output = new BytesStreamOutput(); + detector.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + AnomalyDetector parsedDetector = new AnomalyDetector(input); + assertTrue(parsedDetector.equals(detector)); + } + + public void testWithoutUser() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields("testId", ImmutableList.of("category_field")); + 
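+        // Clearing the user mimics a detector created without security context (e.g. security plugin disabled);
+        // the stream round-trip below checks that serialization tolerates a null user.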
detector.setUser(null); + BytesStreamOutput output = new BytesStreamOutput(); + detector.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + AnomalyDetector parsedDetector = new AnomalyDetector(input); + assertTrue(parsedDetector.equals(detector)); + } + +} diff --git a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java-e b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java-e new file mode 100644 index 000000000..d3298eae2 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java-e @@ -0,0 +1,678 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import static org.opensearch.ad.constant.ADCommonMessages.INVALID_RESULT_INDEX_PREFIX; +import static org.opensearch.ad.constant.ADCommonName.CUSTOM_RESULT_INDEX_PREFIX; +import static org.opensearch.ad.model.AnomalyDetector.MAX_RESULT_INDEX_NAME_SIZE; +import static org.opensearch.timeseries.constant.CommonMessages.INVALID_CHAR_IN_RESULT_INDEX_NAME; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Locale; +import java.util.concurrent.TimeUnit; + +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class AnomalyDetectorTests extends AbstractTimeSeriesTest { + + public void testParseAnomalyDetector() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), Instant.now()); + String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + LOG.info(detectorString); + detectorString = detectorString + .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + assertEquals("Parsing anomaly detector doesn't work", detector, parsedDetector); + } + + public void testParseAnomalyDetectorWithCustomIndex() throws IOException { + String resultIndex = ADCommonName.CUSTOM_RESULT_INDEX_PREFIX + "test"; + AnomalyDetector detector = TestHelpers + .randomDetector( + ImmutableList.of(TestHelpers.randomFeature()), + randomAlphaOfLength(5), + randomIntBetween(1, 5), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(5)), + resultIndex + ); + String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + LOG.info(detectorString); + detectorString = detectorString + .replaceFirst("\\{", 
String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + assertEquals("Parsing result index doesn't work", resultIndex, parsedDetector.getCustomResultIndex()); + assertEquals("Parsing anomaly detector doesn't work", detector, parsedDetector); + } + + public void testAnomalyDetectorWithInvalidCustomIndex() throws Exception { + String resultIndex = ADCommonName.CUSTOM_RESULT_INDEX_PREFIX + "test@@"; + TestHelpers + .assertFailWith( + ValidationException.class, + () -> (TestHelpers + .randomDetector( + ImmutableList.of(TestHelpers.randomFeature()), + randomAlphaOfLength(5), + randomIntBetween(1, 5), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(5)), + resultIndex + )) + ); + } + + public void testParseAnomalyDetectorWithoutParams() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), Instant.now()); + String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder())); + LOG.info(detectorString); + detectorString = detectorString + .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + assertEquals("Parsing anomaly detector doesn't work", detector, parsedDetector); + } + + public void testParseAnomalyDetectorWithCustomDetectionDelay() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), Instant.now()); + String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder())); + LOG.info(detectorString); + TimeValue detectionInterval = new TimeValue(1, TimeUnit.MINUTES); + TimeValue detectionWindowDelay = new TimeValue(10, TimeUnit.MINUTES); + detectorString = detectorString + .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + AnomalyDetector parsedDetector = AnomalyDetector + .parse(TestHelpers.parser(detectorString), detector.getId(), detector.getVersion(), detectionInterval, detectionWindowDelay); + assertEquals("Parsing anomaly detector doesn't work", detector, parsedDetector); + } + + public void testParseSingleEntityAnomalyDetector() throws IOException { + AnomalyDetector detector = TestHelpers + .randomAnomalyDetector( + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomUiMetadata(), + Instant.now(), + AnomalyDetectorType.SINGLE_ENTITY.name() + ); + String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + LOG.info(detectorString); + detectorString = detectorString + .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + assertEquals("Parsing anomaly detector doesn't work", detector, parsedDetector); + } + + public void testParseHistoricalAnomalyDetectorWithoutUser() throws IOException { + AnomalyDetector detector = TestHelpers + .randomAnomalyDetector( + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomUiMetadata(), + Instant.now(), + false, + null + ); + String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder(), 
ToXContent.EMPTY_PARAMS)); + LOG.info(detectorString); + detectorString = detectorString + .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + assertEquals("Parsing anomaly detector doesn't work", detector, parsedDetector); + } + + public void testParseAnomalyDetectorWithNullFilterQuery() throws IOException { + String detectorString = "{\"name\":\"todagtCMkwpcaedpyYUM\",\"description\":" + + "\"ClrcaMpuLfeDSlVduRcKlqPZyqWDBf\",\"time_field\":\"dJRwh\",\"indices\":[\"eIrgWMqAED\"]," + + "\"feature_attributes\":[{\"feature_id\":\"lxYRN\",\"feature_name\":\"eqSeU\",\"feature_enabled\"" + + ":true,\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}],\"detection_interval\":" + + "{\"period\":{\"interval\":425,\"unit\":\"Minutes\"}},\"window_delay\":{\"period\":{\"interval\":973," + + "\"unit\":\"Minutes\"}},\"shingle_size\":4,\"schema_version\":-1203962153,\"ui_metadata\":{\"JbAaV\":{\"feature_id\":" + + "\"rIFjS\",\"feature_name\":\"QXCmS\",\"feature_enabled\":false,\"aggregation_query\":{\"aa\":" + + "{\"value_count\":{\"field\":\"ok\"}}}}},\"last_update_time\":1568396089028}"; + AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + assertTrue(parsedDetector.getFilterQuery() instanceof MatchAllQueryBuilder); + } + + public void testParseAnomalyDetectorWithEmptyFilterQuery() throws IOException { + String detectorString = "{\"name\":\"todagtCMkwpcaedpyYUM\",\"description\":" + + "\"ClrcaMpuLfeDSlVduRcKlqPZyqWDBf\",\"time_field\":\"dJRwh\",\"indices\":[\"eIrgWMqAED\"]," + + "\"feature_attributes\":[{\"feature_id\":\"lxYRN\",\"feature_name\":\"eqSeU\",\"feature_enabled\":" + + "true,\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}],\"filter_query\":{}," + + "\"detection_interval\":{\"period\":{\"interval\":425,\"unit\":\"Minutes\"}},\"window_delay\":" + + "{\"period\":{\"interval\":973,\"unit\":\"Minutes\"}},\"shingle_size\":4,\"schema_version\":-1203962153,\"ui_metadata\":" + + "{\"JbAaV\":{\"feature_id\":\"rIFjS\",\"feature_name\":\"QXCmS\",\"feature_enabled\":false," + + "\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}}," + + "\"last_update_time\":1568396089028}"; + AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + assertTrue(parsedDetector.getFilterQuery() instanceof MatchAllQueryBuilder); + } + + public void testParseAnomalyDetectorWithWrongFilterQuery() throws Exception { + String detectorString = "{\"name\":\"todagtCMkwpcaedpyYUM\",\"description\":" + + "\"ClrcaMpuLfeDSlVduRcKlqPZyqWDBf\",\"time_field\":\"dJRwh\",\"indices\":[\"eIrgWMqAED\"]," + + "\"feature_attributes\":[{\"feature_id\":\"lxYRN\",\"feature_name\":\"eqSeU\",\"feature_enabled\":" + + "true,\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}],\"filter_query\":" + + "{\"aa\":\"bb\"},\"detection_interval\":{\"period\":{\"interval\":425,\"unit\":\"Minutes\"}}," + + "\"window_delay\":{\"period\":{\"interval\":973,\"unit\":\"Minutes\"}},\"shingle_size\":4,\"schema_version\":" + + "-1203962153,\"ui_metadata\":{\"JbAaV\":{\"feature_id\":\"rIFjS\",\"feature_name\":\"QXCmS\"," + + "\"feature_enabled\":false,\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}}," + + "\"last_update_time\":1568396089028}"; + TestHelpers.assertFailWith(ValidationException.class, () -> 
AnomalyDetector.parse(TestHelpers.parser(detectorString))); + } + + public void testParseAnomalyDetectorWithoutOptionalParams() throws IOException { + String detectorString = "{\"name\":\"todagtCMkwpcaedpyYUM\",\"description\":" + + "\"ClrcaMpuLfeDSlVduRcKlqPZyqWDBf\",\"time_field\":\"dJRwh\",\"indices\":[\"eIrgWMqAED\"]," + + "\"feature_attributes\":[{\"feature_id\":\"lxYRN\",\"feature_name\":\"eqSeU\",\"feature_enabled\"" + + ":true,\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}],\"detection_interval\":" + + "{\"period\":{\"interval\":425,\"unit\":\"Minutes\"}},\"schema_version\":-1203962153,\"ui_metadata\":" + + "{\"JbAaV\":{\"feature_id\":\"rIFjS\",\"feature_name\":\"QXCmS\",\"feature_enabled\":false," + + "\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}},\"last_update_time\":1568396089028}"; + AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString), "id", 1L, null, null); + assertTrue(parsedDetector.getFilterQuery() instanceof MatchAllQueryBuilder); + assertEquals((long) parsedDetector.getShingleSize(), (long) TimeSeriesSettings.DEFAULT_SHINGLE_SIZE); + } + + public void testParseAnomalyDetectorWithInvalidShingleSize() throws Exception { + String detectorString = "{\"name\":\"todagtCMkwpcaedpyYUM\",\"description\":" + + "\"ClrcaMpuLfeDSlVduRcKlqPZyqWDBf\",\"time_field\":\"dJRwh\",\"indices\":[\"eIrgWMqAED\"]," + + "\"feature_attributes\":[{\"feature_id\":\"lxYRN\",\"feature_name\":\"eqSeU\",\"feature_enabled\"" + + ":true,\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}],\"detection_interval\":" + + "{\"period\":{\"interval\":425,\"unit\":\"Minutes\"}},\"shingle_size\":-1,\"schema_version\":-1203962153,\"ui_metadata\":" + + "{\"JbAaV\":{\"feature_id\":\"rIFjS\",\"feature_name\":\"QXCmS\",\"feature_enabled\":false," + + "\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}},\"last_update_time\":1568396089028}"; + TestHelpers.assertFailWith(ValidationException.class, () -> AnomalyDetector.parse(TestHelpers.parser(detectorString))); + } + + public void testParseAnomalyDetectorWithNegativeWindowDelay() throws Exception { + String detectorString = "{\"name\":\"todagtCMkwpcaedpyYUM\",\"description\":" + + "\"ClrcaMpuLfeDSlVduRcKlqPZyqWDBf\",\"time_field\":\"dJRwh\",\"indices\":[\"eIrgWMqAED\"]," + + "\"feature_attributes\":[{\"feature_id\":\"lxYRN\",\"feature_name\":\"eqSeU\",\"feature_enabled\"" + + ":true,\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}],\"detection_interval\":" + + "{\"period\":{\"interval\":425,\"unit\":\"Minutes\"}},\"window_delay\":{\"period\":{\"interval\":-973," + + "\"unit\":\"Minutes\"}},\"shingle_size\":4,\"schema_version\":-1203962153,\"ui_metadata\":{\"JbAaV\":{\"feature_id\":" + + "\"rIFjS\",\"feature_name\":\"QXCmS\",\"feature_enabled\":false,\"aggregation_query\":{\"aa\":" + + "{\"value_count\":{\"field\":\"ok\"}}}}},\"last_update_time\":1568396089028}"; + TestHelpers.assertFailWith(ValidationException.class, () -> AnomalyDetector.parse(TestHelpers.parser(detectorString))); + } + + public void testParseAnomalyDetectorWithNegativeDetectionInterval() throws Exception { + String detectorString = "{\"name\":\"todagtCMkwpcaedpyYUM\",\"description\":" + + "\"ClrcaMpuLfeDSlVduRcKlqPZyqWDBf\",\"time_field\":\"dJRwh\",\"indices\":[\"eIrgWMqAED\"]," + + "\"feature_attributes\":[{\"feature_id\":\"lxYRN\",\"feature_name\":\"eqSeU\",\"feature_enabled\"" + + 
":true,\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}],\"detection_interval\":" + + "{\"period\":{\"interval\":-425,\"unit\":\"Minutes\"}},\"window_delay\":{\"period\":{\"interval\":973," + + "\"unit\":\"Minutes\"}},\"shingle_size\":4,\"schema_version\":-1203962153,\"ui_metadata\":{\"JbAaV\":{\"feature_id\":" + + "\"rIFjS\",\"feature_name\":\"QXCmS\",\"feature_enabled\":false,\"aggregation_query\":{\"aa\":" + + "{\"value_count\":{\"field\":\"ok\"}}}}},\"last_update_time\":1568396089028}"; + TestHelpers.assertFailWith(ValidationException.class, () -> AnomalyDetector.parse(TestHelpers.parser(detectorString))); + } + + public void testParseAnomalyDetectorWithIncorrectFeatureQuery() throws Exception { + String detectorString = "{\"name\":\"todagdpyYUM\",\"description\":" + + "\"ClrcaMpuLfeDSlVduRcKlqPZyqWDBf\",\"time_field\":\"dJRwh\",\"indices\":[\"eIrgWMqAED\"]," + + "\"feature_attributes\":[{\"feature_id\":\"lxYRN\",\"feature_name\":\"eqSeU\",\"feature_enabled\"" + + ":true,\"aggregation_query\":{\"aa\":\"bb\"}}],\"detection_interval\":" + + "{\"period\":{\"interval\":425,\"unit\":\"Minutes\"}},\"window_delay\":{\"period\":{\"interval\":973," + + "\"unit\":\"Minutes\"}},\"shingle_size\":4,\"schema_version\":-1203962153,\"ui_metadata\":{\"JbAaV\":{\"feature_id\":" + + "\"rIFjS\",\"feature_name\":\"QXCmS\",\"feature_enabled\":false,\"aggregation_query\":{\"aa\":" + + "{\"value_count\":{\"field\":\"ok\"}}}}},\"last_update_time\":1568396089028}"; + TestHelpers.assertFailWith(ValidationException.class, () -> AnomalyDetector.parse(TestHelpers.parser(detectorString))); + } + + public void testParseAnomalyDetectorWithInvalidDetectorIntervalUnits() { + String detectorString = "{\"name\":\"todagtCMkwpcaedpyYUM\",\"description\":" + + "\"ClrcaMpuLfeDSlVduRcKlqPZyqWDBf\",\"time_field\":\"dJRwh\",\"indices\":[\"eIrgWMqAED\"]," + + "\"feature_attributes\":[{\"feature_id\":\"lxYRN\",\"feature_name\":\"eqSeU\",\"feature_enabled\"" + + ":true,\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}],\"detection_interval\":" + + "{\"period\":{\"interval\":425,\"unit\":\"Millis\"}},\"window_delay\":{\"period\":{\"interval\":973," + + "\"unit\":\"Minutes\"}},\"shingle_size\":4,\"schema_version\":-1203962153,\"ui_metadata\":{\"JbAaV\":{\"feature_id\":" + + "\"rIFjS\",\"feature_name\":\"QXCmS\",\"feature_enabled\":false,\"aggregation_query\":{\"aa\":" + + "{\"value_count\":{\"field\":\"ok\"}}}}},\"last_update_time\":1568396089028}"; + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> AnomalyDetector.parse(TestHelpers.parser(detectorString)) + ); + assertEquals( + String.format(Locale.ROOT, ADCommonMessages.INVALID_TIME_CONFIGURATION_UNITS, ChronoUnit.MILLIS), + exception.getMessage() + ); + } + + public void testParseAnomalyDetectorInvalidWindowDelayUnits() { + String detectorString = "{\"name\":\"todagtCMkwpcaedpyYUM\",\"description\":" + + "\"ClrcaMpuLfeDSlVduRcKlqPZyqWDBf\",\"time_field\":\"dJRwh\",\"indices\":[\"eIrgWMqAED\"]," + + "\"feature_attributes\":[{\"feature_id\":\"lxYRN\",\"feature_name\":\"eqSeU\",\"feature_enabled\"" + + ":true,\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}],\"detection_interval\":" + + "{\"period\":{\"interval\":425,\"unit\":\"Minutes\"}},\"window_delay\":{\"period\":{\"interval\":973," + + "\"unit\":\"Millis\"}},\"shingle_size\":4,\"schema_version\":-1203962153,\"ui_metadata\":{\"JbAaV\":{\"feature_id\":" + + 
"\"rIFjS\",\"feature_name\":\"QXCmS\",\"feature_enabled\":false,\"aggregation_query\":{\"aa\":" + + "{\"value_count\":{\"field\":\"ok\"}}}}},\"last_update_time\":1568396089028}"; + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> AnomalyDetector.parse(TestHelpers.parser(detectorString)) + ); + assertEquals( + String.format(Locale.ROOT, ADCommonMessages.INVALID_TIME_CONFIGURATION_UNITS, ChronoUnit.MILLIS), + exception.getMessage() + ); + } + + public void testParseAnomalyDetectorWithNullUiMetadata() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(null, Instant.now()); + String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + assertEquals("Parsing anomaly detector doesn't work", detector, parsedDetector); + assertNull(parsedDetector.getUiMetadata()); + } + + public void testParseAnomalyDetectorWithEmptyUiMetadata() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of(), Instant.now()); + String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + assertEquals("Parsing anomaly detector doesn't work", detector, parsedDetector); + } + + public void testInvalidShingleSize() throws Exception { + TestHelpers + .assertFailWith( + ValidationException.class, + () -> new AnomalyDetector( + randomAlphaOfLength(5), + randomLong(), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomQuery(), + TestHelpers.randomIntervalTimeConfiguration(), + TestHelpers.randomIntervalTimeConfiguration(), + 0, + null, + 1, + Instant.now(), + null, + TestHelpers.randomUser(), + null, + TestHelpers.randomImputationOption() + ) + ); + } + + public void testNullDetectorName() throws Exception { + TestHelpers + .assertFailWith( + ValidationException.class, + () -> new AnomalyDetector( + randomAlphaOfLength(5), + randomLong(), + null, + randomAlphaOfLength(5), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomQuery(), + TestHelpers.randomIntervalTimeConfiguration(), + TestHelpers.randomIntervalTimeConfiguration(), + TimeSeriesSettings.DEFAULT_SHINGLE_SIZE, + null, + 1, + Instant.now(), + null, + TestHelpers.randomUser(), + null, + TestHelpers.randomImputationOption() + ) + ); + } + + public void testBlankDetectorName() throws Exception { + TestHelpers + .assertFailWith( + ValidationException.class, + () -> new AnomalyDetector( + randomAlphaOfLength(5), + randomLong(), + "", + randomAlphaOfLength(5), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomQuery(), + TestHelpers.randomIntervalTimeConfiguration(), + TestHelpers.randomIntervalTimeConfiguration(), + TimeSeriesSettings.DEFAULT_SHINGLE_SIZE, + null, + 1, + Instant.now(), + null, + TestHelpers.randomUser(), + null, + TestHelpers.randomImputationOption() + ) + ); + } + + public void testNullTimeField() throws Exception { + TestHelpers + .assertFailWith( + ValidationException.class, + () -> new AnomalyDetector( 
+ randomAlphaOfLength(5), + randomLong(), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + null, + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomQuery(), + TestHelpers.randomIntervalTimeConfiguration(), + TestHelpers.randomIntervalTimeConfiguration(), + TimeSeriesSettings.DEFAULT_SHINGLE_SIZE, + null, + 1, + Instant.now(), + null, + TestHelpers.randomUser(), + null, + TestHelpers.randomImputationOption() + ) + ); + } + + public void testNullIndices() throws Exception { + TestHelpers + .assertFailWith( + ValidationException.class, + () -> new AnomalyDetector( + randomAlphaOfLength(5), + randomLong(), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + null, + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomQuery(), + TestHelpers.randomIntervalTimeConfiguration(), + TestHelpers.randomIntervalTimeConfiguration(), + TimeSeriesSettings.DEFAULT_SHINGLE_SIZE, + null, + 1, + Instant.now(), + null, + TestHelpers.randomUser(), + null, + TestHelpers.randomImputationOption() + ) + ); + } + + public void testEmptyIndices() throws Exception { + TestHelpers + .assertFailWith( + ValidationException.class, + () -> new AnomalyDetector( + randomAlphaOfLength(5), + randomLong(), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + ImmutableList.of(), + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomQuery(), + TestHelpers.randomIntervalTimeConfiguration(), + TestHelpers.randomIntervalTimeConfiguration(), + TimeSeriesSettings.DEFAULT_SHINGLE_SIZE, + null, + 1, + Instant.now(), + null, + TestHelpers.randomUser(), + null, + TestHelpers.randomImputationOption() + ) + ); + } + + public void testNullDetectionInterval() throws Exception { + TestHelpers + .assertFailWith( + ValidationException.class, + () -> new AnomalyDetector( + randomAlphaOfLength(5), + randomLong(), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomQuery(), + null, + TestHelpers.randomIntervalTimeConfiguration(), + TimeSeriesSettings.DEFAULT_SHINGLE_SIZE, + null, + 1, + Instant.now(), + null, + TestHelpers.randomUser(), + null, + TestHelpers.randomImputationOption() + ) + ); + } + + public void testInvalidDetectionInterval() { + ValidationException exception = expectThrows( + ValidationException.class, + () -> new AnomalyDetector( + randomAlphaOfLength(10), + randomLong(), + randomAlphaOfLength(20), + randomAlphaOfLength(30), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)), + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomQuery(), + new IntervalTimeConfiguration(0, ChronoUnit.MINUTES), + TestHelpers.randomIntervalTimeConfiguration(), + randomIntBetween(1, 20), + null, + randomInt(), + Instant.now(), + null, + null, + null, + TestHelpers.randomImputationOption() + ) + ); + assertEquals("Detection interval must be a positive integer", exception.getMessage()); + } + + public void testInvalidWindowDelay() { + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new AnomalyDetector( + randomAlphaOfLength(10), + randomLong(), + randomAlphaOfLength(20), + randomAlphaOfLength(30), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)), + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomQuery(), 
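+                // next two arguments: detection interval (1 minute) and window delay (-1 minute);
+                // the negative delay is what triggers the expected "Interval -1 should be non-negative" failure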
+ new IntervalTimeConfiguration(1, ChronoUnit.MINUTES), + new IntervalTimeConfiguration(-1, ChronoUnit.MINUTES), + randomIntBetween(1, 20), + null, + randomInt(), + Instant.now(), + null, + null, + null, + TestHelpers.randomImputationOption() + ) + ); + assertEquals("Interval -1 should be non-negative", exception.getMessage()); + } + + public void testNullFeatures() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(null, null, Instant.now().truncatedTo(ChronoUnit.SECONDS)); + String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + assertEquals(0, parsedDetector.getFeatureAttributes().size()); + } + + public void testEmptyFeatures() throws IOException { + AnomalyDetector detector = TestHelpers + .randomAnomalyDetector(ImmutableList.of(), null, Instant.now().truncatedTo(ChronoUnit.SECONDS)); + String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + assertEquals(0, parsedDetector.getFeatureAttributes().size()); + } + + public void testGetShingleSize() throws IOException { + AnomalyDetector anomalyDetector = new AnomalyDetector( + randomAlphaOfLength(5), + randomLong(), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomQuery(), + TestHelpers.randomIntervalTimeConfiguration(), + TestHelpers.randomIntervalTimeConfiguration(), + 5, + null, + 1, + Instant.now(), + null, + TestHelpers.randomUser(), + null, + TestHelpers.randomImputationOption() + ); + assertEquals((int) anomalyDetector.getShingleSize(), 5); + } + + public void testGetShingleSizeReturnsDefaultValue() throws IOException { + AnomalyDetector anomalyDetector = new AnomalyDetector( + randomAlphaOfLength(5), + randomLong(), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomQuery(), + TestHelpers.randomIntervalTimeConfiguration(), + TestHelpers.randomIntervalTimeConfiguration(), + null, + null, + 1, + Instant.now(), + null, + TestHelpers.randomUser(), + null, + TestHelpers.randomImputationOption() + ); + assertEquals((int) anomalyDetector.getShingleSize(), TimeSeriesSettings.DEFAULT_SHINGLE_SIZE); + } + + public void testNullFeatureAttributes() throws IOException { + AnomalyDetector anomalyDetector = new AnomalyDetector( + randomAlphaOfLength(5), + randomLong(), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(5)), + null, + TestHelpers.randomQuery(), + TestHelpers.randomIntervalTimeConfiguration(), + TestHelpers.randomIntervalTimeConfiguration(), + null, + null, + 1, + Instant.now(), + null, + TestHelpers.randomUser(), + null, + TestHelpers.randomImputationOption() + ); + assertNotNull(anomalyDetector.getFeatureAttributes()); + assertEquals(0, anomalyDetector.getFeatureAttributes().size()); + } + + public void testValidateResultIndex() throws IOException { + AnomalyDetector anomalyDetector = new AnomalyDetector( + randomAlphaOfLength(5), + randomLong(), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + 
randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomQuery(), + TestHelpers.randomIntervalTimeConfiguration(), + TestHelpers.randomIntervalTimeConfiguration(), + null, + null, + 1, + Instant.now(), + null, + TestHelpers.randomUser(), + null, + TestHelpers.randomImputationOption() + ); + + String errorMessage = anomalyDetector.validateCustomResultIndex("abc"); + assertEquals(INVALID_RESULT_INDEX_PREFIX, errorMessage); + + StringBuilder resultIndexNameBuilder = new StringBuilder(CUSTOM_RESULT_INDEX_PREFIX); + for (int i = 0; i < MAX_RESULT_INDEX_NAME_SIZE - CUSTOM_RESULT_INDEX_PREFIX.length(); i++) { + resultIndexNameBuilder.append("a"); + } + assertNull(anomalyDetector.validateCustomResultIndex(resultIndexNameBuilder.toString())); + resultIndexNameBuilder.append("a"); + + errorMessage = anomalyDetector.validateCustomResultIndex(resultIndexNameBuilder.toString()); + assertEquals(AnomalyDetector.INVALID_RESULT_INDEX_NAME_SIZE, errorMessage); + + errorMessage = anomalyDetector.validateCustomResultIndex(CUSTOM_RESULT_INDEX_PREFIX + "abc#"); + assertEquals(INVALID_CHAR_IN_RESULT_INDEX_NAME, errorMessage); + } + + public void testParseAnomalyDetectorWithNoDescription() throws IOException { + String detectorString = "{\"name\":\"todagtCMkwpcaedpyYUM\",\"time_field\":\"dJRwh\",\"indices\":[\"eIrgWMqAED\"]," + + "\"feature_attributes\":[{\"feature_id\":\"lxYRN\",\"feature_name\":\"eqSeU\",\"feature_enabled\"" + + ":true,\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}],\"detection_interval\":" + + "{\"period\":{\"interval\":425,\"unit\":\"Minutes\"}},\"window_delay\":{\"period\":{\"interval\":973," + + "\"unit\":\"Minutes\"}},\"shingle_size\":4,\"schema_version\":-1203962153,\"ui_metadata\":{\"JbAaV\":{\"feature_id\":" + + "\"rIFjS\",\"feature_name\":\"QXCmS\",\"feature_enabled\":false,\"aggregation_query\":{\"aa\":" + + "{\"value_count\":{\"field\":\"ok\"}}}}},\"last_update_time\":1568396089028}"; + AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString), "id", 1L, null, null); + assertEquals(parsedDetector.getDescription(), ""); + } +} diff --git a/src/test/java/org/opensearch/ad/model/AnomalyResultBucketTests.java b/src/test/java/org/opensearch/ad/model/AnomalyResultBucketTests.java index ba4f2084b..ec2daee06 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyResultBucketTests.java +++ b/src/test/java/org/opensearch/ad/model/AnomalyResultBucketTests.java @@ -12,8 +12,8 @@ import java.util.Map; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; diff --git a/src/test/java/org/opensearch/ad/model/AnomalyResultBucketTests.java-e b/src/test/java/org/opensearch/ad/model/AnomalyResultBucketTests.java-e new file mode 100644 index 000000000..629708453 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/AnomalyResultBucketTests.java-e @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.model; + +import static org.opensearch.ad.model.AnomalyResultBucket.*; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import 
org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; + +public class AnomalyResultBucketTests extends AbstractTimeSeriesTest { + + public void testSerializeAnomalyResultBucket() throws IOException { + AnomalyResultBucket anomalyResultBucket = TestHelpers.randomAnomalyResultBucket(); + BytesStreamOutput output = new BytesStreamOutput(); + anomalyResultBucket.writeTo(output); + StreamInput input = output.bytes().streamInput(); + AnomalyResultBucket parsedAnomalyResultBucket = new AnomalyResultBucket(input); + assertTrue(parsedAnomalyResultBucket.equals(anomalyResultBucket)); + } + + public void testAnomalyResultBucketEquals() { + Map<String, Object> keyOne = new HashMap<>(); + keyOne.put("test-field-1", "test-value-1"); + Map<String, Object> keyTwo = new HashMap<>(); + keyTwo.put("test-field-2", "test-value-2"); + AnomalyResultBucket testBucketOne = new AnomalyResultBucket(keyOne, 3, 0.5); + AnomalyResultBucket testBucketTwo = new AnomalyResultBucket(keyOne, 5, 0.75); + AnomalyResultBucket testBucketThree = new AnomalyResultBucket(keyTwo, 7, 0.2); + assertFalse(testBucketOne.equals(testBucketTwo)); + assertFalse(testBucketTwo.equals(testBucketThree)); + } + + @SuppressWarnings("unchecked") + public void testToXContent() throws IOException { + Map<String, Object> key = new HashMap<>() { + { + put("test-field-1", "test-value-1"); + } + }; + int docCount = 5; + double maxAnomalyGrade = 0.5; + AnomalyResultBucket testBucket = new AnomalyResultBucket(key, docCount, maxAnomalyGrade); + XContentBuilder builder = XContentFactory.jsonBuilder(); + testBucket.toXContent(builder, ToXContent.EMPTY_PARAMS); + XContentParser parser = createParser(builder); + Map<String, Object> parsedMap = parser.map(); + + assertEquals(testBucket.getKey().get("test-field-1"), ((Map<String, Object>) parsedMap.get(KEY_FIELD)).get("test-field-1")); + assertEquals(testBucket.getDocCount(), parsedMap.get(DOC_COUNT_FIELD)); + assertEquals(maxAnomalyGrade, (Double) parsedMap.get(MAX_ANOMALY_GRADE_FIELD), 0.000001d); + } +} diff --git a/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java index e5c2e7b41..424de19da 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java @@ -17,15 +17,15 @@ import java.util.Collection; import java.util.Locale; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.plugins.Plugin; import org.opensearch.test.InternalSettingsPlugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import com.google.common.base.Objects; @@ -33,7 +33,7 @@ public class AnomalyResultTests extends OpenSearchSingleNodeTestCase { @Override protected Collection<Class<? extends Plugin>> 
getPlugins() { - return pluginList(InternalSettingsPlugin.class, AnomalyDetectorPlugin.class); + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); } @Override diff --git a/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java-e b/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java-e new file mode 100644 index 000000000..424de19da --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java-e @@ -0,0 +1,157 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import static org.opensearch.test.OpenSearchTestCase.randomDouble; + +import java.io.IOException; +import java.util.Collection; +import java.util.Locale; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.InternalSettingsPlugin; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +import com.google.common.base.Objects; + +public class AnomalyResultTests extends OpenSearchSingleNodeTestCase { + + @Override + protected Collection> getPlugins() { + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); + } + + @Override + protected NamedWriteableRegistry writableRegistry() { + return getInstanceFromNode(NamedWriteableRegistry.class); + } + + public void testParseAnomalyDetector() throws IOException { + AnomalyResult detectResult = TestHelpers.randomAnomalyDetectResult(0.8, randomAlphaOfLength(5), null); + String detectResultString = TestHelpers + .xContentBuilderToString(detectResult.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + detectResultString = detectResultString + .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + AnomalyResult parsedDetectResult = AnomalyResult.parse(TestHelpers.parser(detectResultString)); + assertEquals("Parsing anomaly detect result doesn't work", detectResult, parsedDetectResult); + } + + public void testParseAnomalyDetectorWithoutUser() throws IOException { + AnomalyResult detectResult = TestHelpers.randomAnomalyDetectResult(0.8, randomAlphaOfLength(5), randomAlphaOfLength(5), false); + String detectResultString = TestHelpers + .xContentBuilderToString(detectResult.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + detectResultString = detectResultString + .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + AnomalyResult parsedDetectResult = AnomalyResult.parse(TestHelpers.parser(detectResultString)); + assertEquals("Parsing anomaly detect result doesn't work", detectResult, parsedDetectResult); + } + + public void testParseAnomalyDetectorWithoutNormalResult() throws IOException { + AnomalyResult detectResult = TestHelpers.randomHCADAnomalyDetectResult(randomDouble(), randomDouble(), null); + + String detectResultString = TestHelpers + 
.xContentBuilderToString(detectResult.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + detectResultString = detectResultString + .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + AnomalyResult parsedDetectResult = AnomalyResult.parse(TestHelpers.parser(detectResultString)); + assertTrue( + Objects.equal(detectResult.getConfigId(), parsedDetectResult.getConfigId()) + && Objects.equal(detectResult.getTaskId(), parsedDetectResult.getTaskId()) + && Objects.equal(detectResult.getAnomalyScore(), parsedDetectResult.getAnomalyScore()) + && Objects.equal(detectResult.getAnomalyGrade(), parsedDetectResult.getAnomalyGrade()) + && Objects.equal(detectResult.getConfidence(), parsedDetectResult.getConfidence()) + && Objects.equal(detectResult.getDataStartTime(), parsedDetectResult.getDataStartTime()) + && Objects.equal(detectResult.getDataEndTime(), parsedDetectResult.getDataEndTime()) + && Objects.equal(detectResult.getExecutionStartTime(), parsedDetectResult.getExecutionStartTime()) + && Objects.equal(detectResult.getExecutionEndTime(), parsedDetectResult.getExecutionEndTime()) + && Objects.equal(detectResult.getError(), parsedDetectResult.getError()) + && Objects.equal(detectResult.getEntity(), parsedDetectResult.getEntity()) + && Objects.equal(detectResult.getFeatureData(), parsedDetectResult.getFeatureData()) + ); + } + + public void testParseAnomalyDetectorWithNanAnomalyResult() throws IOException { + AnomalyResult detectResult = TestHelpers.randomHCADAnomalyDetectResult(Double.NaN, Double.NaN, null); + String detectResultString = TestHelpers + .xContentBuilderToString(detectResult.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + detectResultString = detectResultString + .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + AnomalyResult parsedDetectResult = AnomalyResult.parse(TestHelpers.parser(detectResultString)); + assertNull(parsedDetectResult.getAnomalyGrade()); + assertNull(parsedDetectResult.getAnomalyScore()); + assertTrue( + Objects.equal(detectResult.getConfigId(), parsedDetectResult.getConfigId()) + && Objects.equal(detectResult.getTaskId(), parsedDetectResult.getTaskId()) + && Objects.equal(detectResult.getFeatureData(), parsedDetectResult.getFeatureData()) + && Objects.equal(detectResult.getDataStartTime(), parsedDetectResult.getDataStartTime()) + && Objects.equal(detectResult.getDataEndTime(), parsedDetectResult.getDataEndTime()) + && Objects.equal(detectResult.getExecutionStartTime(), parsedDetectResult.getExecutionStartTime()) + && Objects.equal(detectResult.getExecutionEndTime(), parsedDetectResult.getExecutionEndTime()) + && Objects.equal(detectResult.getError(), parsedDetectResult.getError()) + && Objects.equal(detectResult.getEntity(), parsedDetectResult.getEntity()) + && Objects.equal(detectResult.getConfidence(), parsedDetectResult.getConfidence()) + ); + } + + public void testParseAnomalyDetectorWithTaskId() throws IOException { + AnomalyResult detectResult = TestHelpers.randomAnomalyDetectResult(0.8, randomAlphaOfLength(5), randomAlphaOfLength(5)); + String detectResultString = TestHelpers + .xContentBuilderToString(detectResult.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + detectResultString = detectResultString + .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + AnomalyResult parsedDetectResult = 
AnomalyResult.parse(TestHelpers.parser(detectResultString)); + assertEquals("Parsing anomaly detect result doesn't work", detectResult, parsedDetectResult); + } + + public void testParseAnomalyDetectorWithEntity() throws IOException { + AnomalyResult detectResult = TestHelpers.randomHCADAnomalyDetectResult(0.8, 0.5); + String detectResultString = TestHelpers + .xContentBuilderToString(detectResult.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + detectResultString = detectResultString + .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + AnomalyResult parsedDetectResult = AnomalyResult.parse(TestHelpers.parser(detectResultString)); + assertEquals("Parsing anomaly detect result doesn't work", detectResult, parsedDetectResult); + } + + public void testSerializeAnomalyResult() throws IOException { + AnomalyResult detectResult = TestHelpers.randomAnomalyDetectResult(0.8, randomAlphaOfLength(5), randomAlphaOfLength(5)); + BytesStreamOutput output = new BytesStreamOutput(); + detectResult.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + AnomalyResult parsedDetectResult = new AnomalyResult(input); + assertTrue(parsedDetectResult.equals(detectResult)); + } + + public void testSerializeAnomalyResultWithoutUser() throws IOException { + AnomalyResult detectResult = TestHelpers.randomAnomalyDetectResult(0.8, randomAlphaOfLength(5), randomAlphaOfLength(5), false); + BytesStreamOutput output = new BytesStreamOutput(); + detectResult.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + AnomalyResult parsedDetectResult = new AnomalyResult(input); + assertTrue(parsedDetectResult.equals(detectResult)); + } + + public void testSerializeAnomalyResultWithEntity() throws IOException { + AnomalyResult detectResult = TestHelpers.randomHCADAnomalyDetectResult(0.8, 0.5); + BytesStreamOutput output = new BytesStreamOutput(); + detectResult.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + AnomalyResult parsedDetectResult = new AnomalyResult(input); + assertTrue(parsedDetectResult.equals(detectResult)); + } +} diff --git a/src/test/java/org/opensearch/ad/model/DetectionDateRangeTests.java b/src/test/java/org/opensearch/ad/model/DetectionDateRangeTests.java index dc563d38c..ab507b027 100644 --- a/src/test/java/org/opensearch/ad/model/DetectionDateRangeTests.java +++ b/src/test/java/org/opensearch/ad/model/DetectionDateRangeTests.java @@ -17,22 +17,22 @@ import java.util.Collection; import java.util.Locale; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.plugins.Plugin; import org.opensearch.test.InternalSettingsPlugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.model.DateRange; public class 
DetectionDateRangeTests extends OpenSearchSingleNodeTestCase { @Override protected Collection> getPlugins() { - return pluginList(InternalSettingsPlugin.class, AnomalyDetectorPlugin.class); + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); } @Override diff --git a/src/test/java/org/opensearch/ad/model/DetectionDateRangeTests.java-e b/src/test/java/org/opensearch/ad/model/DetectionDateRangeTests.java-e new file mode 100644 index 000000000..ab507b027 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/DetectionDateRangeTests.java-e @@ -0,0 +1,79 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collection; +import java.util.Locale; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.InternalSettingsPlugin; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.model.DateRange; + +public class DetectionDateRangeTests extends OpenSearchSingleNodeTestCase { + + @Override + protected Collection> getPlugins() { + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); + } + + @Override + protected NamedWriteableRegistry writableRegistry() { + return getInstanceFromNode(NamedWriteableRegistry.class); + } + + public void testParseDetectionDateRangeWithNullStartTime() { + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new DateRange(null, Instant.now())); + assertEquals("Detection data range's start time must not be null", exception.getMessage()); + } + + public void testParseDetectionDateRangeWithNullEndTime() { + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new DateRange(Instant.now(), null)); + assertEquals("Detection data range's end time must not be null", exception.getMessage()); + } + + public void testInvalidDateRange() { + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new DateRange(Instant.now(), Instant.now().minus(10, ChronoUnit.MINUTES)) + ); + assertEquals("Detection data range's end time must be after start time", exception.getMessage()); + } + + public void testSerializeDetectoinDateRange() throws IOException { + DateRange dateRange = TestHelpers.randomDetectionDateRange(); + BytesStreamOutput output = new BytesStreamOutput(); + dateRange.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + DateRange parsedDateRange = new DateRange(input); + assertTrue(parsedDateRange.equals(dateRange)); + } + + public void testParseDetectionDateRange() throws IOException { + DateRange dateRange = TestHelpers.randomDetectionDateRange(); + String dateRangeString = 
TestHelpers.xContentBuilderToString(dateRange.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + dateRangeString = dateRangeString + .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + DateRange parsedDateRange = DateRange.parse(TestHelpers.parser(dateRangeString)); + assertEquals("Parsing detection range doesn't work", dateRange, parsedDateRange); + } + +} diff --git a/src/test/java/org/opensearch/ad/model/DetectorInternalStateTests.java-e b/src/test/java/org/opensearch/ad/model/DetectorInternalStateTests.java-e new file mode 100644 index 000000000..2ea993b72 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/DetectorInternalStateTests.java-e @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.model; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TestHelpers; + +public class DetectorInternalStateTests extends OpenSearchSingleNodeTestCase { + + public void testToXContentDetectorInternalState() throws IOException { + DetectorInternalState internalState = new DetectorInternalState.Builder() + .lastUpdateTime(Instant.ofEpochMilli(100L)) + .error("error-test") + .build(); + String internalStateString = TestHelpers + .xContentBuilderToString(internalState.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + DetectorInternalState parsedInternalState = DetectorInternalState.parse(TestHelpers.parser(internalStateString)); + assertEquals(internalState, parsedInternalState); + } + + public void testClonedDetectorInternalState() throws IOException { + DetectorInternalState originalState = new DetectorInternalState.Builder() + .lastUpdateTime(Instant.ofEpochMilli(100L)) + .error("error-test") + .build(); + DetectorInternalState clonedState = (DetectorInternalState) originalState.clone(); + // parse original InternalState + String internalStateString = TestHelpers + .xContentBuilderToString(originalState.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + DetectorInternalState parsedInternalState = DetectorInternalState.parse(TestHelpers.parser(internalStateString)); + // compare parsed to cloned + assertEquals(clonedState, parsedInternalState); + } +} diff --git a/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java b/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java index b7289ab26..9960a5fe2 100644 --- a/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java +++ b/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java @@ -17,7 +17,7 @@ import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.timeseries.TestHelpers; diff --git a/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java-e b/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java-e new file mode 100644 index 000000000..9960a5fe2 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java-e @@ -0,0 +1,115 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + 
* The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.Entity; + +public class DetectorProfileTests extends OpenSearchTestCase { + + private DetectorProfile createRandomDetectorProfile() { + return new DetectorProfile.Builder() + .state(DetectorState.INIT) + .error(randomAlphaOfLength(5)) + .modelProfile( + new ModelProfileOnNode[] { + new ModelProfileOnNode( + randomAlphaOfLength(10), + new ModelProfile( + randomAlphaOfLength(5), + Entity.createSingleAttributeEntity(randomAlphaOfLength(5), randomAlphaOfLength(5)), + randomLong() + ) + ) } + ) + .shingleSize(randomInt()) + .coordinatingNode(randomAlphaOfLength(10)) + .totalSizeInBytes(-1) + .totalEntities(randomLong()) + .activeEntities(randomLong()) + .adTaskProfile( + new ADTaskProfile( + randomAlphaOfLength(5), + randomInt(), + randomLong(), + randomBoolean(), + randomInt(), + randomLong(), + randomAlphaOfLength(5) + ) + ) + .build(); + } + + public void testParseDetectorProfile() throws IOException { + DetectorProfile detectorProfile = createRandomDetectorProfile(); + BytesStreamOutput output = new BytesStreamOutput(); + detectorProfile.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + DetectorProfile parsedDetectorProfile = new DetectorProfile(input); + assertEquals("Detector profile serialization doesn't work", detectorProfile, parsedDetectorProfile); + } + + public void testMergeDetectorProfile() { + DetectorProfile detectorProfileOne = createRandomDetectorProfile(); + DetectorProfile detectorProfileTwo = createRandomDetectorProfile(); + String errorPreMerge = detectorProfileOne.getError(); + detectorProfileOne.merge(detectorProfileTwo); + assertTrue(detectorProfileOne.toString().contains(detectorProfileTwo.getError())); + assertFalse(detectorProfileOne.toString().contains(errorPreMerge)); + assertTrue(detectorProfileOne.toString().contains(detectorProfileTwo.getCoordinatingNode())); + } + + public void testDetectorProfileToXContent() throws IOException { + DetectorProfile detectorProfile = createRandomDetectorProfile(); + String detectorProfileString = TestHelpers.xContentBuilderToString(detectorProfile.toXContent(TestHelpers.builder())); + XContentParser parser = TestHelpers.parser(detectorProfileString); + Map parsedMap = parser.map(); + assertEquals(detectorProfile.getCoordinatingNode(), parsedMap.get("coordinating_node")); + assertEquals(detectorProfile.getState().toString(), parsedMap.get("state")); + assertTrue(parsedMap.get("models").toString().contains(detectorProfile.getModelProfile()[0].getModelId())); + } + + public void testDetectorProfileName() throws IllegalArgumentException { + assertEquals("ad_task", DetectorProfileName.getName(ADCommonName.AD_TASK).getName()); + assertEquals("state", 
DetectorProfileName.getName(ADCommonName.STATE).getName()); + assertEquals("error", DetectorProfileName.getName(ADCommonName.ERROR).getName()); + assertEquals("coordinating_node", DetectorProfileName.getName(ADCommonName.COORDINATING_NODE).getName()); + assertEquals("shingle_size", DetectorProfileName.getName(ADCommonName.SHINGLE_SIZE).getName()); + assertEquals("total_size_in_bytes", DetectorProfileName.getName(ADCommonName.TOTAL_SIZE_IN_BYTES).getName()); + assertEquals("models", DetectorProfileName.getName(ADCommonName.MODELS).getName()); + assertEquals("init_progress", DetectorProfileName.getName(ADCommonName.INIT_PROGRESS).getName()); + assertEquals("total_entities", DetectorProfileName.getName(ADCommonName.TOTAL_ENTITIES).getName()); + assertEquals("active_entities", DetectorProfileName.getName(ADCommonName.ACTIVE_ENTITIES).getName()); + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> DetectorProfileName.getName("abc")); + assertEquals(exception.getMessage(), ADCommonMessages.UNSUPPORTED_PROFILE_TYPE); + } + + public void testDetectorProfileSet() throws IllegalArgumentException { + DetectorProfile detectorProfileOne = createRandomDetectorProfile(); + detectorProfileOne.setShingleSize(20); + assertEquals(20, detectorProfileOne.getShingleSize()); + detectorProfileOne.setActiveEntities(10L); + assertEquals(10L, (long) detectorProfileOne.getActiveEntities()); + detectorProfileOne.setModelCount(10L); + assertEquals(10L, (long) detectorProfileOne.getActiveEntities()); + } +} diff --git a/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java-e b/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java-e new file mode 100644 index 000000000..24cb0c879 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java-e @@ -0,0 +1,107 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.model; + +import static java.util.Arrays.asList; +import static org.opensearch.timeseries.TestHelpers.randomHCADAnomalyDetectResult; + +import java.util.ArrayList; +import java.util.List; + +import org.junit.Test; +import org.opensearch.ad.stats.ADStatsResponse; +import org.opensearch.test.OpenSearchTestCase; + +public class EntityAnomalyResultTests extends OpenSearchTestCase { + + @Test + public void testGetAnomalyResults() { + AnomalyResult anomalyResult1 = randomHCADAnomalyDetectResult(0.25, 0.25, "error"); + AnomalyResult anomalyResult2 = randomHCADAnomalyDetectResult(0.5, 0.5, "error"); + List anomalyResults = new ArrayList() { + { + add(anomalyResult1); + add(anomalyResult2); + } + }; + EntityAnomalyResult entityAnomalyResult = new EntityAnomalyResult(anomalyResults); + + assertEquals(anomalyResults, entityAnomalyResult.getAnomalyResults()); + } + + @Test + public void testMerge() { + AnomalyResult anomalyResult1 = randomHCADAnomalyDetectResult(0.25, 0.25, "error"); + AnomalyResult anomalyResult2 = randomHCADAnomalyDetectResult(0.5, 0.5, "error"); + + EntityAnomalyResult entityAnomalyResult1 = new EntityAnomalyResult(new ArrayList() { + { + add(anomalyResult1); + } + }); + EntityAnomalyResult entityAnomalyResult2 = new EntityAnomalyResult(new ArrayList() { + { + add(anomalyResult2); + } + }); + entityAnomalyResult2.merge(entityAnomalyResult1); + + assertEquals(asList(anomalyResult2, anomalyResult1), entityAnomalyResult2.getAnomalyResults()); + } + + @Test + public void testMerge_null() { + AnomalyResult anomalyResult = randomHCADAnomalyDetectResult(0.25, 0.25, "error"); + + EntityAnomalyResult entityAnomalyResult = new EntityAnomalyResult(new ArrayList() { + { + add(anomalyResult); + } + }); + + entityAnomalyResult.merge(null); + + assertEquals(asList(anomalyResult), entityAnomalyResult.getAnomalyResults()); + } + + @Test + public void testMerge_self() { + AnomalyResult anomalyResult = randomHCADAnomalyDetectResult(0.25, 0.25, "error"); + + EntityAnomalyResult entityAnomalyResult = new EntityAnomalyResult(new ArrayList() { + { + add(anomalyResult); + } + }); + + entityAnomalyResult.merge(entityAnomalyResult); + + assertEquals(asList(anomalyResult), entityAnomalyResult.getAnomalyResults()); + } + + @Test + public void testMerge_otherClass() { + ADStatsResponse adStatsResponse = new ADStatsResponse(); + AnomalyResult anomalyResult = randomHCADAnomalyDetectResult(0.25, 0.25, "error"); + + EntityAnomalyResult entityAnomalyResult = new EntityAnomalyResult(new ArrayList() { + { + add(anomalyResult); + } + }); + + entityAnomalyResult.merge(adStatsResponse); + + assertEquals(asList(anomalyResult), entityAnomalyResult.getAnomalyResults()); + } + +} diff --git a/src/test/java/org/opensearch/ad/model/EntityProfileTests.java-e b/src/test/java/org/opensearch/ad/model/EntityProfileTests.java-e new file mode 100644 index 000000000..caa45a2d8 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/EntityProfileTests.java-e @@ -0,0 +1,71 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.model; + +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; + +import java.io.IOException; + +import org.opensearch.ad.common.exception.JsonPathNotFoundException; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.common.Strings; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.AbstractTimeSeriesTest; + +import test.org.opensearch.ad.util.JsonDeserializer; + +public class EntityProfileTests extends AbstractTimeSeriesTest { + public void testMerge() { + EntityProfile profile1 = new EntityProfile(null, -1, -1, null, null, EntityState.INIT); + EntityProfile profile2 = new EntityProfile(null, -1, -1, null, null, EntityState.UNKNOWN); + profile1.merge(profile2); + assertEquals(profile1.getState(), EntityState.INIT); + assertTrue(profile1.toString().contains(EntityState.INIT.toString())); + } + + public void testToXContent() throws IOException, JsonPathNotFoundException { + EntityProfile profile1 = new EntityProfile(null, -1, -1, null, null, EntityState.INIT); + + XContentBuilder builder = jsonBuilder(); + profile1.toXContent(builder, ToXContent.EMPTY_PARAMS); + String json = Strings.toString(builder); + + assertEquals("INIT", JsonDeserializer.getTextValue(json, ADCommonName.STATE)); + + EntityProfile profile2 = new EntityProfile(null, -1, -1, null, null, EntityState.UNKNOWN); + + builder = jsonBuilder(); + profile2.toXContent(builder, ToXContent.EMPTY_PARAMS); + json = Strings.toString(builder); + + assertTrue(false == JsonDeserializer.hasChildNode(json, ADCommonName.STATE)); + } + + public void testToXContentTimeStampAboveZero() throws IOException, JsonPathNotFoundException { + EntityProfile profile1 = new EntityProfile(null, 1, 1, null, null, EntityState.INIT); + + XContentBuilder builder = jsonBuilder(); + profile1.toXContent(builder, ToXContent.EMPTY_PARAMS); + String json = Strings.toString(builder); + + assertEquals("INIT", JsonDeserializer.getTextValue(json, ADCommonName.STATE)); + + EntityProfile profile2 = new EntityProfile(null, 1, 1, null, null, EntityState.UNKNOWN); + + builder = jsonBuilder(); + profile2.toXContent(builder, ToXContent.EMPTY_PARAMS); + json = Strings.toString(builder); + + assertTrue(false == JsonDeserializer.hasChildNode(json, ADCommonName.STATE)); + } +} diff --git a/src/test/java/org/opensearch/ad/model/EntityTests.java-e b/src/test/java/org/opensearch/ad/model/EntityTests.java-e new file mode 100644 index 000000000..7c645d920 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/EntityTests.java-e @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.model; + +import java.util.Collections; +import java.util.Optional; +import java.util.TreeMap; + +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.model.Entity; + +public class EntityTests extends AbstractTimeSeriesTest { + /** + * Test that toString has no random string, but only attributes + */ + public void testToString() { + TreeMap<String, String> attributes = new TreeMap<>(); + String name1 = "host"; + String val1 = "server_2"; + String name2 = "service"; + String val2 = "app_4"; + attributes.put(name1, val1); + attributes.put(name2, val2); + Entity entity = Entity.createEntityFromOrderedMap(attributes); + assertEquals("host=server_2,service=app_4", entity.toString()); + } + + public void test_getModelId_returnId_withNoAttributes() { + String detectorId = "id"; + Entity entity = Entity.createEntityByReordering(Collections.emptyMap()); + + Optional<String> modelId = entity.getModelId(detectorId); + + assertTrue(!modelId.isPresent()); + } +} diff --git a/src/test/java/org/opensearch/ad/model/FeatureDataTests.java-e b/src/test/java/org/opensearch/ad/model/FeatureDataTests.java-e new file mode 100644 index 000000000..bfd17bb95 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/FeatureDataTests.java-e @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.FeatureData; + +public class FeatureDataTests extends OpenSearchTestCase { + + public void testParseAnomalyDetector() throws IOException { + FeatureData featureData = TestHelpers.randomFeatureData(); + String featureDataString = TestHelpers + .xContentBuilderToString(featureData.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + featureDataString = featureDataString + .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + FeatureData parsedFeatureData = FeatureData.parse(TestHelpers.parser(featureDataString)); + assertEquals("Parsing feature data doesn't work", featureData, parsedFeatureData); + } +} diff --git a/src/test/java/org/opensearch/ad/model/FeatureTests.java-e b/src/test/java/org/opensearch/ad/model/FeatureTests.java-e new file mode 100644 index 000000000..bc3baafe8 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/FeatureTests.java-e @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.model; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.Feature; + +public class FeatureTests extends OpenSearchTestCase { + + public void testParseFeature() throws IOException { + Feature feature = TestHelpers.randomFeature(); + String featureString = TestHelpers.xContentBuilderToString(feature.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + featureString = featureString + .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); + Feature parsedFeature = Feature.parse(TestHelpers.parser(featureString)); + assertEquals("Parsing feature doesn't work", feature, parsedFeature); + } + + public void testNullName() throws Exception { + TestHelpers + .assertFailWith( + IllegalArgumentException.class, + () -> new Feature(randomAlphaOfLength(5), null, true, TestHelpers.randomAggregation()) + ); + } + +} diff --git a/src/test/java/org/opensearch/ad/model/IntervalTimeConfigurationTests.java-e b/src/test/java/org/opensearch/ad/model/IntervalTimeConfigurationTests.java-e new file mode 100644 index 000000000..970d9fd89 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/IntervalTimeConfigurationTests.java-e @@ -0,0 +1,73 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.io.IOException; +import java.time.Duration; +import java.time.temporal.ChronoUnit; +import java.util.Locale; + +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.TimeConfiguration; + +public class IntervalTimeConfigurationTests extends OpenSearchTestCase { + + public void testParseIntervalSchedule() throws IOException { + TimeConfiguration schedule = TestHelpers.randomIntervalTimeConfiguration(); + String scheduleString = TestHelpers.xContentBuilderToString(schedule.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + scheduleString = scheduleString + .replaceFirst( + "\"interval", + String.format(Locale.ROOT, "\"%s\":\"%s\",\"interval", randomAlphaOfLength(5), randomAlphaOfLength(5)) + ); + TimeConfiguration parsedSchedule = TimeConfiguration.parse(TestHelpers.parser(scheduleString)); + assertEquals("Parsing interval schedule doesn't work", schedule, parsedSchedule); + } + + public void testParseWrongScheduleType() throws Exception { + TimeConfiguration schedule = TestHelpers.randomIntervalTimeConfiguration(); + String scheduleString = TestHelpers.xContentBuilderToString(schedule.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + String finalScheduleString = scheduleString.replaceFirst("period", randomAlphaOfLength(5)); + TestHelpers + .assertFailWith( + IllegalArgumentException.class, + "Find no schedule definition", + () -> TimeConfiguration.parse(TestHelpers.parser(finalScheduleString)) + ); + } + + public void testWrongInterval() throws Exception { + TestHelpers + .assertFailWith( + 
IllegalArgumentException.class, + "should be non-negative", + () -> new IntervalTimeConfiguration(randomLongBetween(-100, -1), ChronoUnit.MINUTES) + ); + } + + public void testWrongUnit() throws Exception { + TestHelpers + .assertFailWith( + IllegalArgumentException.class, + "is not supported", + () -> new IntervalTimeConfiguration(randomLongBetween(1, 100), ChronoUnit.MILLIS) + ); + } + + public void testToDuration() { + IntervalTimeConfiguration timeConfig = new IntervalTimeConfiguration(/*interval*/1, ChronoUnit.MINUTES); + assertEquals(Duration.ofMillis(60_000L), timeConfig.toDuration()); + } +} diff --git a/src/test/java/org/opensearch/ad/model/MergeableListTests.java-e b/src/test/java/org/opensearch/ad/model/MergeableListTests.java-e new file mode 100644 index 000000000..f9d794da6 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/MergeableListTests.java-e @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.model; + +import java.util.ArrayList; +import java.util.List; + +import org.opensearch.timeseries.AbstractTimeSeriesTest; + +public class MergeableListTests extends AbstractTimeSeriesTest { + + public void testMergeableListGetElements() { + List<String> ls1 = new ArrayList<String>(); + ls1.add("item1"); + ls1.add("item2"); + MergeableList<String> mergeList = new MergeableList<>(ls1); + assertEquals(ls1, mergeList.getElements()); + } + + public void testMergeableListMerge() { + List<String> ls1 = new ArrayList<String>(); + ls1.add("item1"); + ls1.add("item2"); + List<String> ls2 = new ArrayList<String>(); + ls2.add("item3"); + ls2.add("item4"); + MergeableList<String> mergeListOne = new MergeableList<>(ls1); + MergeableList<String> mergeListTwo = new MergeableList<>(ls2); + mergeListOne.merge(mergeListTwo); + assertEquals(4, mergeListOne.getElements().size()); + assertEquals("item3", mergeListOne.getElements().get(2)); + } + + public void testMergeableListFailMerge() { + List<String> ls1 = new ArrayList<>(); + ls1.add("item1"); + ls1.add("item2"); + MergeableList<String> mergeListOne = new MergeableList<>(ls1); + MergeableList<String> mergeListTwo = new MergeableList<>(null); + mergeListOne.merge(mergeListTwo); + assertEquals(2, mergeListOne.getElements().size()); + } +} diff --git a/src/test/java/org/opensearch/ad/model/ModelProfileTests.java-e b/src/test/java/org/opensearch/ad/model/ModelProfileTests.java-e new file mode 100644 index 000000000..d00f0f5c5 --- /dev/null +++ b/src/test/java/org/opensearch/ad/model/ModelProfileTests.java-e @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.model; + +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; + +import java.io.IOException; + +import org.opensearch.common.Strings; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; + +import test.org.opensearch.ad.util.JsonDeserializer; + +public class ModelProfileTests extends AbstractTimeSeriesTest { + + public void testToXContent() throws IOException { + ModelProfile profile1 = new ModelProfile( + randomAlphaOfLength(5), + Entity.createSingleAttributeEntity(randomAlphaOfLength(5), randomAlphaOfLength(5)), + 0 + ); + XContentBuilder builder = getBuilder(profile1); + String json = Strings.toString(builder); + assertTrue(JsonDeserializer.hasChildNode(json, CommonName.ENTITY_KEY)); + assertFalse(JsonDeserializer.hasChildNode(json, CommonName.MODEL_SIZE_IN_BYTES)); + + ModelProfile profile2 = new ModelProfile(randomAlphaOfLength(5), null, 1); + + builder = getBuilder(profile2); + json = Strings.toString(builder); + + assertFalse(JsonDeserializer.hasChildNode(json, CommonName.ENTITY_KEY)); + assertTrue(JsonDeserializer.hasChildNode(json, CommonName.MODEL_SIZE_IN_BYTES)); + + } + + private XContentBuilder getBuilder(ModelProfile profile) throws IOException { + XContentBuilder builder = jsonBuilder(); + builder.startObject(); + profile.toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + return builder; + } +} diff --git a/src/test/java/org/opensearch/ad/plugin/MockReindexPlugin.java-e b/src/test/java/org/opensearch/ad/plugin/MockReindexPlugin.java-e new file mode 100644 index 000000000..079a93c9f --- /dev/null +++ b/src/test/java/org/opensearch/ad/plugin/MockReindexPlugin.java-e @@ -0,0 +1,178 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.plugin; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.BulkByScrollTask; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.index.reindex.UpdateByQueryAction; +import org.opensearch.index.reindex.UpdateByQueryRequest; +import org.opensearch.plugins.ActionPlugin; +import org.opensearch.plugins.Plugin; +import org.opensearch.search.SearchHit; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableList; + +public class MockReindexPlugin extends Plugin implements ActionPlugin { + + @Override + public List> getActions() { + return Arrays + .asList( + new ActionHandler<>(UpdateByQueryAction.INSTANCE, MockTransportUpdateByQueryAction.class), + new ActionHandler<>(DeleteByQueryAction.INSTANCE, MockTransportDeleteByQueryAction.class) + ); + } + + public static class MockTransportUpdateByQueryAction extends HandledTransportAction { + + @Inject + public MockTransportUpdateByQueryAction(ActionFilters actionFilters, TransportService transportService) { + super(UpdateByQueryAction.NAME, transportService, actionFilters, UpdateByQueryRequest::new); + } + + @Override + protected void doExecute(Task task, UpdateByQueryRequest request, ActionListener listener) { + BulkByScrollResponse response = null; + try { + XContentParser parser = TestHelpers + .parser( + "{\"slice_id\":1,\"total\":2,\"updated\":3,\"created\":0,\"deleted\":0,\"batches\":6," + + "\"version_conflicts\":0,\"noops\":0,\"retries\":{\"bulk\":0,\"search\":10}," + + "\"throttled_millis\":0,\"requests_per_second\":13.0,\"canceled\":\"reasonCancelled\"," + + "\"throttled_until_millis\":14}" + ); + parser.nextToken(); + response = new BulkByScrollResponse( + TimeValue.timeValueMillis(10), + BulkByScrollTask.Status.innerFromXContent(parser), + ImmutableList.of(), + ImmutableList.of(), + false + ); + } catch (IOException exception) { + logger.error(exception); + } + listener.onResponse(response); + } + } + + public static class MockTransportDeleteByQueryAction extends HandledTransportAction { + + private Client client; + + @Inject + public MockTransportDeleteByQueryAction(ActionFilters actionFilters, TransportService transportService, Client client) { + super(DeleteByQueryAction.NAME, transportService, actionFilters, DeleteByQueryRequest::new); + this.client = client; + } + + private class MultiResponsesActionListener 
implements ActionListener { + private final ActionListener delegate; + private final AtomicInteger collectedResponseCount; + private final AtomicLong maxResponseCount; + private final AtomicBoolean hasFailure; + + MultiResponsesActionListener(ActionListener delegate, long maxResponseCount) { + this.delegate = delegate; + this.collectedResponseCount = new AtomicInteger(0); + this.maxResponseCount = new AtomicLong(maxResponseCount); + this.hasFailure = new AtomicBoolean(false); + } + + @Override + public void onResponse(DeleteResponse deleteResponse) { + if (collectedResponseCount.incrementAndGet() >= maxResponseCount.get()) { + finish(); + } + } + + @Override + public void onFailure(Exception e) { + this.hasFailure.set(true); + if (collectedResponseCount.incrementAndGet() >= maxResponseCount.get()) { + finish(); + } + } + + private void finish() { + if (this.hasFailure.get()) { + this.delegate.onFailure(new RuntimeException("failed to delete old AD tasks")); + } else { + try { + XContentParser parser = TestHelpers + .parser( + "{\"slice_id\":1,\"total\":2,\"updated\":0,\"created\":0,\"deleted\":" + + maxResponseCount + + ",\"batches\":6,\"version_conflicts\":0,\"noops\":0,\"retries\":{\"bulk\":0," + + "\"search\":10},\"throttled_millis\":0,\"requests_per_second\":13.0,\"canceled\":" + + "\"reasonCancelled\",\"throttled_until_millis\":14}" + ); + parser.nextToken(); + BulkByScrollResponse response = new BulkByScrollResponse( + TimeValue.timeValueMillis(10), + BulkByScrollTask.Status.innerFromXContent(parser), + ImmutableList.of(), + ImmutableList.of(), + false + ); + this.delegate.onResponse(response); + } catch (IOException exception) { + this.delegate.onFailure(new RuntimeException("failed to parse BulkByScrollResponse")); + } + } + } + } + + @Override + protected void doExecute(Task task, DeleteByQueryRequest request, ActionListener listener) { + SearchRequest searchRequest = request.getSearchRequest(); + client.search(searchRequest, ActionListener.wrap(r -> { + long totalHits = r.getHits().getTotalHits().value; + MultiResponsesActionListener delegateListener = new MultiResponsesActionListener(listener, totalHits); + Iterator iterator = r.getHits().iterator(); + while (iterator.hasNext()) { + String id = iterator.next().getId(); + DeleteRequest deleteRequest = new DeleteRequest(ADCommonName.DETECTION_STATE_INDEX, id) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.delete(deleteRequest, delegateListener); + } + }, e -> listener.onFailure(e))); + } + } +} diff --git a/src/test/java/org/opensearch/ad/ratelimit/AbstractRateLimitingTest.java-e b/src/test/java/org/opensearch/ad/ratelimit/AbstractRateLimitingTest.java-e new file mode 100644 index 000000000..075cf46c3 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ratelimit/AbstractRateLimitingTest.java-e @@ -0,0 +1,66 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.time.Clock; +import java.time.Instant; +import java.util.Arrays; +import java.util.Optional; + +import org.opensearch.action.ActionListener; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.Entity; + +public class AbstractRateLimitingTest extends AbstractTimeSeriesTest { + Clock clock; + AnomalyDetector detector; + NodeStateManager nodeStateManager; + String detectorId; + String categoryField; + Entity entity, entity2, entity3; + + @SuppressWarnings("unchecked") + @Override + public void setUp() throws Exception { + super.setUp(); + + clock = mock(Clock.class); + when(clock.instant()).thenReturn(Instant.now()); + + threadPool = mock(ThreadPool.class); + setUpADThreadPool(threadPool); + + categoryField = "a"; + detectorId = "123"; + detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(categoryField)); + + nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + entity = Entity.createSingleAttributeEntity(categoryField, "value"); + entity2 = Entity.createSingleAttributeEntity(categoryField, "value2"); + entity3 = Entity.createSingleAttributeEntity(categoryField, "value3"); + } +} diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapterTests.java-e b/src/test/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapterTests.java-e new file mode 100644 index 000000000..d1fe526de --- /dev/null +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapterTests.java-e @@ -0,0 +1,129 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Optional; + +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.ml.CheckpointDao; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; + +import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; + +public class CheckPointMaintainRequestAdapterTests extends AbstractRateLimitingTest { + private CacheProvider cache; + private CheckpointDao checkpointDao; + private String indexName; + private Setting checkpointInterval; + private CheckPointMaintainRequestAdapter adapter; + private ModelState state; + private CheckpointMaintainRequest request; + private ClusterService clusterService; + + @Override + public void setUp() throws Exception { + super.setUp(); + cache = mock(CacheProvider.class); + checkpointDao = mock(CheckpointDao.class); + indexName = ADCommonName.CHECKPOINT_INDEX_NAME; + checkpointInterval = AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ; + EntityCache entityCache = mock(EntityCache.class); + when(cache.get()).thenReturn(entityCache); + state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + when(entityCache.getForMaintainance(anyString(), anyString())).thenReturn(Optional.of(state)); + clusterService = mock(ClusterService.class); + ClusterSettings settings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ))) + ); + when(clusterService.getClusterSettings()).thenReturn(settings); + adapter = new CheckPointMaintainRequestAdapter( + cache, + checkpointDao, + indexName, + checkpointInterval, + clock, + clusterService, + Settings.EMPTY + ); + request = new CheckpointMaintainRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity.getModelId(detectorId).get()); + + } + + public void testShouldNotSave() { + when(checkpointDao.shouldSave(any(), anyBoolean(), any(), any())).thenReturn(false); + assertTrue(adapter.convert(request).isEmpty()); + } + + public void testIndexSourceNull() throws IOException { + when(checkpointDao.shouldSave(any(), anyBoolean(), any(), any())).thenReturn(true); + when(checkpointDao.toIndexSource(any())).thenReturn(null); + assertTrue(adapter.convert(request).isEmpty()); + } + + public void testIndexSourceEmpty() throws IOException { + when(checkpointDao.shouldSave(any(), anyBoolean(), any(), any())).thenReturn(true); + when(checkpointDao.toIndexSource(any())).thenReturn(new HashMap()); + assertTrue(adapter.convert(request).isEmpty()); + } + + public void 
testModelIdEmpty() throws IOException { + when(checkpointDao.shouldSave(any(), anyBoolean(), any(), any())).thenReturn(true); + Map content = new HashMap(); + content.put("a", "b"); + when(checkpointDao.toIndexSource(any())).thenReturn(content); + assertTrue(adapter.convert(new CheckpointMaintainRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, null)).isEmpty()); + } + + public void testNormal() throws IOException { + when(checkpointDao.shouldSave(any(), anyBoolean(), any(), any())).thenReturn(true); + Map content = new HashMap(); + content.put("a", "b"); + when(checkpointDao.toIndexSource(any())).thenReturn(content); + Optional converted = adapter.convert(request); + assertTrue(!converted.isEmpty()); + UpdateRequest updateRequest = converted.get().getUpdateRequest(); + UpdateRequest expectedRequest = new UpdateRequest(indexName, entity.getModelId(detectorId).get()).docAsUpsert(true).doc(content); + assertEquals(updateRequest.docAsUpsert(), expectedRequest.docAsUpsert()); + assertEquals(updateRequest.detectNoop(), expectedRequest.detectNoop()); + assertEquals(updateRequest.fetchSource(), expectedRequest.fetchSource()); + } + + public void testIndexSourceException() throws IOException { + doThrow(IllegalArgumentException.class).when(checkpointDao).toIndexSource(any()); + assertTrue(adapter.convert(request).isEmpty()); + } +} diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorkerTests.java-e b/src/test/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorkerTests.java-e new file mode 100644 index 000000000..cba7e8a45 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorkerTests.java-e @@ -0,0 +1,165 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Random; + +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.ml.CheckpointDao; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; + +import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; + +public class CheckpointMaintainWorkerTests extends AbstractRateLimitingTest { + ClusterService clusterService; + CheckpointMaintainWorker cpMaintainWorker; + CheckpointWriteWorker writeWorker; + CheckpointMaintainRequest request; + CheckpointMaintainRequest request2; + List requests; + CheckpointDao checkpointDao; + + @Override + public void setUp() throws Exception { + super.setUp(); + clusterService = mock(ClusterService.class); + Settings settings = Settings.builder().put(AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE.getKey(), 1).build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.AD_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS, + AnomalyDetectorSettings.CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ + ) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + writeWorker = mock(CheckpointWriteWorker.class); + + CacheProvider cache = mock(CacheProvider.class); + checkpointDao = mock(CheckpointDao.class); + String indexName = ADCommonName.CHECKPOINT_INDEX_NAME; + Setting checkpointInterval = AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ; + EntityCache entityCache = mock(EntityCache.class); + when(cache.get()).thenReturn(entityCache); + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + when(entityCache.getForMaintainance(anyString(), anyString())).thenReturn(Optional.of(state)); + CheckPointMaintainRequestAdapter adapter = new CheckPointMaintainRequestAdapter( + cache, + checkpointDao, + indexName, + checkpointInterval, + clock, + clusterService, + settings + ); + + // Integer.MAX_VALUE makes a huge heap + cpMaintainWorker = new CheckpointMaintainWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + 
AnomalyDetectorSettings.CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + settings, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + writeWorker, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + nodeStateManager, + adapter + ); + + request = new CheckpointMaintainRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, entity.getModelId(detectorId).get()); + request2 = new CheckpointMaintainRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, entity2.getModelId(detectorId).get()); + + requests = new ArrayList<>(); + requests.add(request); + requests.add(request2); + + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + + TimeValue value = invocation.getArgument(1); + // since we have only 1 request each time + long expectedExecutionPerRequestMilli = AnomalyDetectorSettings.AD_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS + .getDefault(Settings.EMPTY); + long delay = value.getMillis(); + assertTrue(delay == expectedExecutionPerRequestMilli); + return null; + }).when(threadPool).schedule(any(), any(), any()); + } + + public void testPutRequests() throws IOException { + when(checkpointDao.shouldSave(any(), anyBoolean(), any(), any())).thenReturn(true); + Map content = new HashMap(); + content.put("a", "b"); + when(checkpointDao.toIndexSource(any())).thenReturn(content); + + cpMaintainWorker.putAll(requests); + + verify(writeWorker, times(2)).putAll(any()); + verify(threadPool, times(2)).schedule(any(), any(), any()); + } + + public void testFailtoPut() throws IOException { + when(checkpointDao.shouldSave(any(), anyBoolean(), any(), any())).thenReturn(false); + + cpMaintainWorker.putAll(requests); + + verify(writeWorker, never()).putAll(any()); + verify(threadPool, never()).schedule(any(), any(), any()); + } +} diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java index f23dbd484..76090cce9 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java @@ -46,7 +46,6 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.get.MultiGetItemResponse; import org.opensearch.action.get.MultiGetResponse; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.caching.CacheProvider; import org.opensearch.ad.caching.EntityCache; @@ -66,13 +65,14 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; +import org.opensearch.core.rest.RestStatus; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.get.GetResult; import org.opensearch.index.seqno.SequenceNumbers; -import org.opensearch.rest.RestStatus; import org.opensearch.threadpool.ThreadPoolStats; import org.opensearch.threadpool.ThreadPoolStats.Stats; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.model.Entity; import 
org.opensearch.timeseries.stats.StatNames; @@ -529,7 +529,7 @@ public void testRetryableException() { public void testRemoveUnusedQueues() { // do nothing when putting a request to keep queues not empty ExecutorService executorService = mock(ExecutorService.class); - when(threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + when(threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); worker = new CheckpointReadWorker( Integer.MAX_VALUE, @@ -575,7 +575,7 @@ public void testRemoveUnusedQueues() { private void maintenanceSetup() { // do nothing when putting a request to keep queues not empty ExecutorService executorService = mock(ExecutorService.class); - when(threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + when(threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); when(threadPool.stats()).thenReturn(new ThreadPoolStats(new ArrayList())); } diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java-e b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java-e new file mode 100644 index 000000000..6b80082a4 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java-e @@ -0,0 +1,817 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import static java.util.AbstractMap.SimpleImmutableEntry; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.mockito.Mockito; +import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.get.MultiGetItemResponse; +import org.opensearch.action.get.MultiGetResponse; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.CheckpointDao; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ThresholdingResult; +import 
org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.stats.ADStat; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.stats.suppliers.CounterSupplier; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.get.GetResult; +import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.threadpool.ThreadPoolStats; +import org.opensearch.threadpool.ThreadPoolStats.Stats; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.stats.StatNames; + +import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; + +import com.fasterxml.jackson.core.JsonParseException; + +public class CheckpointReadWorkerTests extends AbstractRateLimitingTest { + CheckpointReadWorker worker; + + CheckpointDao checkpoint; + ClusterService clusterService; + + ModelState state; + + CheckpointWriteWorker checkpointWriteQueue; + ModelManager modelManager; + EntityColdStartWorker coldstartQueue; + ResultWriteWorker resultWriteQueue; + ADIndexManagement anomalyDetectionIndices; + CacheProvider cacheProvider; + EntityCache entityCache; + EntityFeatureRequest request, request2, request3; + ClusterSettings clusterSettings; + ADStats adStats; + + @Override + public void setUp() throws Exception { + super.setUp(); + clusterService = mock(ClusterService.class); + clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE + ) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + + checkpoint = mock(CheckpointDao.class); + + Map.Entry entry = new SimpleImmutableEntry(state.getModel(), Instant.now()); + when(checkpoint.processGetResponse(any(), anyString())).thenReturn(Optional.of(entry)); + + checkpointWriteQueue = mock(CheckpointWriteWorker.class); + + modelManager = mock(ModelManager.class); + when(modelManager.processEntityCheckpoint(any(), any(), anyString(), anyString(), anyInt())).thenReturn(state); + when(modelManager.score(any(), anyString(), any())).thenReturn(new ThresholdingResult(0, 1, 0.7)); + + coldstartQueue = mock(EntityColdStartWorker.class); + resultWriteQueue = mock(ResultWriteWorker.class); + anomalyDetectionIndices = mock(ADIndexManagement.class); + + cacheProvider = mock(CacheProvider.class); + entityCache = mock(EntityCache.class); + when(cacheProvider.get()).thenReturn(entityCache); + when(entityCache.hostIfPossible(any(), any())).thenReturn(true); + + Map> statsMap = new HashMap>() { + { + put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + } + }; + + adStats = new ADStats(statsMap); + + // Integer.MAX_VALUE makes a huge heap + 
worker = new CheckpointReadWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + modelManager, + checkpoint, + coldstartQueue, + resultWriteQueue, + nodeStateManager, + anomalyDetectionIndices, + cacheProvider, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + adStats + ); + + request = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity, new double[] { 0 }, 0); + request2 = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity2, new double[] { 0 }, 0); + request3 = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity3, new double[] { 0 }, 0); + } + + static class RegularSetUpConfig { + private final boolean canHostModel; + private final boolean fullModel; + + RegularSetUpConfig(Builder builder) { + this.canHostModel = builder.canHostModel; + this.fullModel = builder.fullModel; + } + + public static class Builder { + boolean canHostModel = true; + boolean fullModel = true; + + Builder canHostModel(boolean canHostModel) { + this.canHostModel = canHostModel; + return this; + } + + Builder fullModel(boolean fullModel) { + this.fullModel = fullModel; + return this; + } + + public RegularSetUpConfig build() { + return new RegularSetUpConfig(this); + } + } + } + + private void regularTestSetUp(RegularSetUpConfig config) { + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[1]; + items[0] = new MultiGetItemResponse( + new GetResponse( + new GetResult(ADCommonName.CHECKPOINT_INDEX_NAME, entity.getModelId(detectorId).get(), 1, 1, 0, true, null, null, null) + ), + null + ); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + when(entityCache.hostIfPossible(any(), any())).thenReturn(config.canHostModel); + + state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(config.fullModel).build()); + when(modelManager.processEntityCheckpoint(any(), any(), anyString(), anyString(), anyInt())).thenReturn(state); + if (config.fullModel) { + when(modelManager.getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt())) + .thenReturn(new ThresholdingResult(0, 1, 1)); + } else { + when(modelManager.getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt())) + .thenReturn(new ThresholdingResult(0, 0, 0)); + } + + List requests = new ArrayList<>(); + requests.add(request); + worker.putAll(requests); + } + + public void testRegular() { + regularTestSetUp(new RegularSetUpConfig.Builder().build()); + + verify(resultWriteQueue, times(1)).put(any()); + verify(checkpointWriteQueue, never()).write(any(), anyBoolean(), any()); + } + + public void testCannotLoadModel() { + regularTestSetUp(new RegularSetUpConfig.Builder().canHostModel(false).build()); + + verify(resultWriteQueue, times(1)).put(any()); + verify(checkpointWriteQueue, times(1)).write(any(), anyBoolean(), any()); + } + + public void testNoFullModel() { + regularTestSetUp(new 
RegularSetUpConfig.Builder().fullModel(false).build()); + verify(resultWriteQueue, never()).put(any()); + verify(checkpointWriteQueue, never()).write(any(), anyBoolean(), any()); + } + + public void testIndexNotFound() { + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[1]; + items[0] = new MultiGetItemResponse( + null, + new MultiGetResponse.Failure( + ADCommonName.CHECKPOINT_INDEX_NAME, + entity.getModelId(detectorId).get(), + new IndexNotFoundException(ADCommonName.CHECKPOINT_INDEX_NAME) + ) + ); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + worker.put(request); + verify(coldstartQueue, times(1)).put(any()); + verify(entityCache, never()).hostIfPossible(any(), any()); + } + + public void testAllDocNotFound() { + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[2]; + items[0] = new MultiGetItemResponse( + new GetResponse( + new GetResult( + ADCommonName.CHECKPOINT_INDEX_NAME, + entity.getModelId(detectorId).get(), + SequenceNumbers.UNASSIGNED_SEQ_NO, + SequenceNumbers.UNASSIGNED_PRIMARY_TERM, + 0, + false, + null, + null, + null + ) + ), + null + ); + items[1] = new MultiGetItemResponse( + new GetResponse( + new GetResult( + ADCommonName.CHECKPOINT_INDEX_NAME, + entity2.getModelId(detectorId).get(), + SequenceNumbers.UNASSIGNED_SEQ_NO, + SequenceNumbers.UNASSIGNED_PRIMARY_TERM, + 0, + false, + null, + null, + null + ) + ), + null + ); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + List requests = new ArrayList<>(); + requests.add(request); + requests.add(request2); + worker.putAll(requests); + + verify(coldstartQueue, times(2)).put(any()); + verify(entityCache, never()).hostIfPossible(any(), any()); + } + + public void testSingleDocNotFound() { + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[2]; + items[0] = new MultiGetItemResponse( + new GetResponse( + new GetResult(ADCommonName.CHECKPOINT_INDEX_NAME, entity.getModelId(detectorId).get(), 1, 1, 0, true, null, null, null) + ), + null + ); + items[1] = new MultiGetItemResponse( + new GetResponse( + new GetResult( + ADCommonName.CHECKPOINT_INDEX_NAME, + entity2.getModelId(detectorId).get(), + SequenceNumbers.UNASSIGNED_SEQ_NO, + SequenceNumbers.UNASSIGNED_PRIMARY_TERM, + 0, + false, + null, + null, + null + ) + ), + null + ); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + List requests = new ArrayList<>(); + requests.add(request); + requests.add(request2); + worker.putAll(requests); + verify(coldstartQueue, times(1)).put(any()); + verify(entityCache, times(1)).hostIfPossible(any(), any()); + } + + public void testTimeout() { + AtomicBoolean retried = new AtomicBoolean(); + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[2]; + if (!retried.get()) { + items[0] = new MultiGetItemResponse( + null, + new MultiGetResponse.Failure( + ADCommonName.CHECKPOINT_INDEX_NAME, + entity.getModelId(detectorId).get(), + new OpenSearchStatusException("blah", RestStatus.REQUEST_TIMEOUT) + ) + ); + items[1] = new MultiGetItemResponse( + null, + new MultiGetResponse.Failure( + ADCommonName.CHECKPOINT_INDEX_NAME, + entity2.getModelId(detectorId).get(), 
+ new OpenSearchStatusException("blah", RestStatus.CONFLICT) + ) + ); + retried.set(true); + } else { + items[0] = new MultiGetItemResponse( + new GetResponse( + new GetResult( + ADCommonName.CHECKPOINT_INDEX_NAME, + entity.getModelId(detectorId).get(), + 1, + 1, + 0, + true, + null, + null, + null + ) + ), + null + ); + items[1] = new MultiGetItemResponse( + new GetResponse( + new GetResult( + ADCommonName.CHECKPOINT_INDEX_NAME, + entity2.getModelId(detectorId).get(), + 1, + 1, + 0, + true, + null, + null, + null + ) + ), + null + ); + } + + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + List requests = new ArrayList<>(); + requests.add(request); + requests.add(request2); + worker.putAll(requests); + // two retried requests and the original putAll trigger 3 batchRead in total. + // It is possible the two retries requests get combined into one batchRead + verify(checkpoint, Mockito.atLeast(2)).batchRead(any(), any()); + assertTrue(retried.get()); + } + + public void testOverloadedExceptionFromResponse() { + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[1]; + items[0] = new MultiGetItemResponse( + null, + new MultiGetResponse.Failure( + ADCommonName.CHECKPOINT_INDEX_NAME, + entity.getModelId(detectorId).get(), + new OpenSearchRejectedExecutionException("blah") + ) + ); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + worker.put(request); + verify(coldstartQueue, never()).put(any()); + verify(entityCache, never()).hostIfPossible(any(), any()); + worker.put(request); + // the 2nd put won't trigger batchRead as we are in cool down mode + verify(checkpoint, times(1)).batchRead(any(), any()); + } + + public void testOverloadedExceptionFromFailure() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new OpenSearchRejectedExecutionException("blah")); + return null; + }).when(checkpoint).batchRead(any(), any()); + + worker.put(request); + verify(coldstartQueue, never()).put(any()); + verify(entityCache, never()).hostIfPossible(any(), any()); + worker.put(request); + // the 2nd put won't trigger batchRead as we are in cool down mode + verify(checkpoint, times(1)).batchRead(any(), any()); + } + + public void testUnexpectedException() { + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[1]; + items[0] = new MultiGetItemResponse( + null, + new MultiGetResponse.Failure( + ADCommonName.CHECKPOINT_INDEX_NAME, + entity.getModelId(detectorId).get(), + new IllegalArgumentException("blah") + ) + ); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + worker.put(request); + verify(coldstartQueue, never()).put(any()); + verify(entityCache, never()).hostIfPossible(any(), any()); + } + + public void testRetryableException() { + AtomicBoolean retried = new AtomicBoolean(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + if (retried.get()) { + // not retryable + listener.onFailure(new JsonParseException(null, "blah")); + } else { + // retryable + retried.set(true); + listener.onFailure(new OpenSearchException("blah")); + } + + return null; + }).when(checkpoint).batchRead(any(), any()); + + 
worker.put(request); + verify(coldstartQueue, never()).put(any()); + verify(entityCache, never()).hostIfPossible(any(), any()); + assertTrue(retried.get()); + } + + public void testRemoveUnusedQueues() { + // do nothing when putting a request to keep queues not empty + ExecutorService executorService = mock(ExecutorService.class); + when(threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + + worker = new CheckpointReadWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + modelManager, + checkpoint, + coldstartQueue, + resultWriteQueue, + nodeStateManager, + anomalyDetectionIndices, + cacheProvider, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + adStats + ); + + regularTestSetUp(new RegularSetUpConfig.Builder().build()); + + assertTrue(!worker.isQueueEmpty()); + assertEquals(CheckpointReadWorker.WORKER_NAME, worker.getWorkerName()); + + // make RequestQueue.expired return true + when(clock.instant()).thenReturn(Instant.now().plusSeconds(AnomalyDetectorSettings.HOURLY_MAINTENANCE.getSeconds() + 1)); + + // removed the expired queue + worker.maintenance(); + + assertTrue(worker.isQueueEmpty()); + } + + private void maintenanceSetup() { + // do nothing when putting a request to keep queues not empty + ExecutorService executorService = mock(ExecutorService.class); + when(threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + when(threadPool.stats()).thenReturn(new ThreadPoolStats(new ArrayList())); + } + + public void testSettingUpdatable() { + maintenanceSetup(); + + // can host two requests in the queue + worker = new CheckpointReadWorker( + 2000, + 1, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + modelManager, + checkpoint, + coldstartQueue, + resultWriteQueue, + nodeStateManager, + anomalyDetectionIndices, + cacheProvider, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + adStats + ); + + List requests = new ArrayList<>(); + requests.add(request); + requests.add(request2); + worker.putAll(requests); + // size not exceeded, thus no effect + worker.maintenance(); + assertTrue(!worker.isQueueEmpty()); + + Settings newSettings = Settings + .builder() + .put(AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT.getKey(), "0.0001") + .build(); + Settings.Builder target = Settings.builder(); + clusterSettings.updateDynamicSettings(newSettings, target, Settings.builder(), "test"); + clusterSettings.applySettings(target.build()); + // size not exceeded after changing setting + worker.maintenance(); + assertTrue(worker.isQueueEmpty()); + } + + public void testOpenCircuitBreaker() { + maintenanceSetup(); + + 
ADCircuitBreakerService breaker = mock(ADCircuitBreakerService.class); + when(breaker.isOpen()).thenReturn(true); + + worker = new CheckpointReadWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + breaker, + threadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + modelManager, + checkpoint, + coldstartQueue, + resultWriteQueue, + nodeStateManager, + anomalyDetectionIndices, + cacheProvider, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + checkpointWriteQueue, + adStats + ); + + List requests = new ArrayList<>(); + requests.add(request); + requests.add(request2); + worker.putAll(requests); + + // due to open circuit breaker, removed one request + worker.maintenance(); + assertTrue(!worker.isQueueEmpty()); + + // one request per batch + Settings newSettings = Settings.builder().put(AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE.getKey(), "1").build(); + Settings.Builder target = Settings.builder(); + clusterSettings.updateDynamicSettings(newSettings, target, Settings.builder(), "test"); + clusterSettings.applySettings(target.build()); + + // enable executing requests + setUpADThreadPool(threadPool); + + // listener returns response back and trigger calls to process extra requests + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[1]; + items[0] = new MultiGetItemResponse( + new GetResponse( + new GetResult(ADCommonName.CHECKPOINT_INDEX_NAME, entity.getModelId(detectorId).get(), 1, 1, 0, true, null, null, null) + ), + null + ); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + // trigger request execution + worker.put(request3); + assertTrue(worker.isQueueEmpty()); + + // two requests in the queue trigger two batches + verify(checkpoint, times(2)).batchRead(any(), any()); + } + + public void testChangePriority() { + assertEquals(RequestPriority.MEDIUM, request.getPriority()); + RequestPriority newPriority = RequestPriority.HIGH; + request.setPriority(newPriority); + assertEquals(newPriority, request.getPriority()); + } + + public void testDetectorId() { + assertEquals(detectorId, request.getId()); + String newDetectorId = "456"; + request.setDetectorId(newDetectorId); + assertEquals(newDetectorId, request.getId()); + } + + @SuppressWarnings("unchecked") + public void testHostException() throws IOException { + String detectorId2 = "456"; + Entity entity4 = Entity.createSingleAttributeEntity(categoryField, "value4"); + EntityFeatureRequest request4 = new EntityFeatureRequest( + Integer.MAX_VALUE, + detectorId2, + RequestPriority.MEDIUM, + entity4, + new double[] { 0 }, + 0 + ); + + AnomalyDetector detector2 = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId2, Arrays.asList(categoryField)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector2)); + return null; + }).when(nodeStateManager).getAnomalyDetector(eq(detectorId2), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + 
listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(eq(detectorId), any(ActionListener.class)); + + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[2]; + items[0] = new MultiGetItemResponse( + new GetResponse( + new GetResult(ADCommonName.CHECKPOINT_INDEX_NAME, entity.getModelId(detectorId).get(), 1, 1, 0, true, null, null, null) + ), + null + ); + items[1] = new MultiGetItemResponse( + new GetResponse( + new GetResult( + ADCommonName.CHECKPOINT_INDEX_NAME, + entity4.getModelId(detectorId2).get(), + 1, + 1, + 0, + true, + null, + null, + null + ) + ), + null + ); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + doThrow(LimitExceededException.class).when(entityCache).hostIfPossible(eq(detector2), any()); + + List requests = new ArrayList<>(); + requests.add(request); + requests.add(request4); + worker.putAll(requests); + verify(coldstartQueue, never()).put(any()); + verify(entityCache, times(2)).hostIfPossible(any(), any()); + + verify(nodeStateManager, times(1)).setException(eq(detectorId2), any(LimitExceededException.class)); + verify(nodeStateManager, never()).setException(eq(detectorId), any(LimitExceededException.class)); + } + + public void testFailToScore() { + doAnswer(invocation -> { + MultiGetItemResponse[] items = new MultiGetItemResponse[1]; + items[0] = new MultiGetItemResponse( + new GetResponse( + new GetResult(ADCommonName.CHECKPOINT_INDEX_NAME, entity.getModelId(detectorId).get(), 1, 1, 0, true, null, null, null) + ), + null + ); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(new MultiGetResponse(items)); + return null; + }).when(checkpoint).batchRead(any(), any()); + + state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + when(modelManager.processEntityCheckpoint(any(), any(), anyString(), anyString(), anyInt())).thenReturn(state); + doThrow(new IllegalArgumentException()).when(modelManager).getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt()); + + List requests = new ArrayList<>(); + requests.add(request); + worker.putAll(requests); + + verify(resultWriteQueue, never()).put(any()); + verify(checkpointWriteQueue, never()).write(any(), anyBoolean(), any()); + verify(coldstartQueue, times(1)).put(any()); + Object val = adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).getValue(); + assertEquals(1L, ((Long) val).longValue()); + } +} diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java index a3afda641..97e8370bf 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java @@ -46,7 +46,6 @@ import org.opensearch.action.bulk.BulkItemResponse.Failure; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.index.IndexResponse; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.ml.CheckpointDao; @@ -58,11 +57,12 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; -import 
org.opensearch.index.Index; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; import org.opensearch.index.engine.VersionConflictEngineException; -import org.opensearch.index.shard.ShardId; -import org.opensearch.rest.RestStatus; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.constant.CommonName; import test.org.opensearch.ad.util.MLUtil; @@ -182,7 +182,7 @@ public void testTriggerAutoFlush() throws InterruptedException { ExecutorService executorService = mock(ExecutorService.class); ThreadPool mockThreadPool = mock(ThreadPool.class); - when(mockThreadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + when(mockThreadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); doAnswer(invocation -> { Runnable runnable = () -> { try { diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java-e b/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java-e new file mode 100644 index 000000000..7fbc6d179 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java-e @@ -0,0 +1,432 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.ConcurrentModificationException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkItemResponse.Failure; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.ml.CheckpointDao; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import 
org.opensearch.common.settings.Settings; +import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; +import org.opensearch.core.index.Index; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.constant.CommonName; + +import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; + +public class CheckpointWriteWorkerTests extends AbstractRateLimitingTest { + CheckpointWriteWorker worker; + + CheckpointDao checkpoint; + ClusterService clusterService; + + ModelState state; + + @Override + @SuppressWarnings("unchecked") + public void setUp() throws Exception { + super.setUp(); + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE + ) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + checkpoint = mock(CheckpointDao.class); + Map checkpointMap = new HashMap<>(); + checkpointMap.put(CommonName.FIELD_MODEL, "a"); + when(checkpoint.toIndexSource(any())).thenReturn(checkpointMap); + when(checkpoint.shouldSave(any(), anyBoolean(), any(), any())).thenReturn(true); + + // Integer.MAX_VALUE makes a huge heap + worker = new CheckpointWriteWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + checkpoint, + ADCommonName.CHECKPOINT_INDEX_NAME, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + nodeStateManager, + AnomalyDetectorSettings.HOURLY_MAINTENANCE + ); + + state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().build()); + } + + public void testTriggerSave() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + BulkItemResponse[] responses = new BulkItemResponse[1]; + ShardId shardId = new ShardId(new Index("index_name", "uuid"), 0); + responses[0] = new BulkItemResponse( + 0, + randomFrom(DocWriteRequest.OpType.values()), + new IndexResponse(shardId, "id", 1, 1, 1, true) + ); + listener.onResponse(new BulkResponse(responses, 1)); + + return null; + }).when(checkpoint).batchWrite(any(), any()); + + worker.write(state, true, RequestPriority.MEDIUM); + + verify(checkpoint, times(1)).batchWrite(any(), any()); + } + + public void testTriggerSaveAll() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + BulkItemResponse[] responses = new BulkItemResponse[1]; + ShardId shardId = new ShardId(new Index("index_name", "uuid"), 0); + responses[0] = new BulkItemResponse( + 0, + randomFrom(DocWriteRequest.OpType.values()), + new IndexResponse(shardId, "id", 1, 1, 1, true) + ); + 
listener.onResponse(new BulkResponse(responses, 1)); + + return null; + }).when(checkpoint).batchWrite(any(), any()); + + List> states = new ArrayList<>(); + states.add(state); + worker.writeAll(states, detectorId, true, RequestPriority.MEDIUM); + + verify(checkpoint, times(1)).batchWrite(any(), any()); + } + + /** + * Test that when more requests are coming than concurrency allowed, queues will be + * auto-flushed given enough time. + * @throws InterruptedException when thread.sleep gets interrupted + */ + public void testTriggerAutoFlush() throws InterruptedException { + final CountDownLatch processingLatch = new CountDownLatch(1); + + ExecutorService executorService = mock(ExecutorService.class); + + ThreadPool mockThreadPool = mock(ThreadPool.class); + when(mockThreadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + doAnswer(invocation -> { + Runnable runnable = () -> { + try { + processingLatch.await(100, TimeUnit.SECONDS); + } catch (InterruptedException e) { + LOG.error(e); + assertTrue("Unexpected exception", false); + } + Runnable toInvoke = invocation.getArgument(0); + toInvoke.run(); + }; + // start a new thread so it won't block main test thread's execution + new Thread(runnable).start(); + return null; + }).when(executorService).execute(any(Runnable.class)); + + // make sure permits are released and the next request probe starts + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(checkpoint).batchWrite(any(), any()); + + // Integer.MAX_VALUE makes a huge heap + // create a worker to use mockThreadPool + worker = new CheckpointWriteWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + mockThreadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + checkpoint, + ADCommonName.CHECKPOINT_INDEX_NAME, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + nodeStateManager, + AnomalyDetectorSettings.HOURLY_MAINTENANCE + ); + + // our concurrency is 2, so first 2 requests cause two batches. And the + // remaining 1 stays in the queue until the 2 concurrent runs finish. + // first 2 batch account for one checkpoint.batchWrite; the remaining one + // calls checkpoint.batchWrite + // CHECKPOINT_WRITE_QUEUE_BATCH_SIZE is the largest batch size + int numberOfRequests = 2 * AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE.getDefault(Settings.EMPTY) + 1; + for (int i = 0; i < numberOfRequests; i++) { + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().build()); + worker.write(state, true, RequestPriority.MEDIUM); + } + + // Here, we allow the first 2 pulling batch from queue operations to start. + processingLatch.countDown(); + + // wait until queues get emptied + int waitIntervals = 20; + while (!worker.isQueueEmpty() && waitIntervals-- >= 0) { + Thread.sleep(500); + } + + assertTrue(worker.isQueueEmpty()); + // of requests cause at least one batch. 
+ verify(checkpoint, times(3)).batchWrite(any(), any()); + } + + public void testOverloaded() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new OpenSearchRejectedExecutionException("blah", true)); + + return null; + }).when(checkpoint).batchWrite(any(), any()); + + worker.write(state, true, RequestPriority.MEDIUM); + + verify(checkpoint, times(1)).batchWrite(any(), any()); + verify(nodeStateManager, times(1)).setException(eq(state.getId()), any(OpenSearchRejectedExecutionException.class)); + } + + public void testRetryException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new OpenSearchStatusException("blah", RestStatus.REQUEST_TIMEOUT)); + + return null; + }).when(checkpoint).batchWrite(any(), any()); + + worker.write(state, true, RequestPriority.MEDIUM); + // we don't retry checkpoint write + verify(checkpoint, times(1)).batchWrite(any(), any()); + verify(nodeStateManager, times(1)).setException(eq(state.getId()), any(OpenSearchStatusException.class)); + } + + /** + * Test that we don't retry a failed request + */ + public void testFailedRequest() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + BulkItemResponse[] responses = new BulkItemResponse[1]; + ShardId shardId = new ShardId(new Index("index_name", "uuid"), 0); + responses[0] = new BulkItemResponse( + 0, + randomFrom(DocWriteRequest.OpType.values()), + new Failure(shardId.getIndexName(), "id1", new VersionConflictEngineException(shardId, "id1", "blah")) + ); + listener.onResponse(new BulkResponse(responses, 1)); + + return null; + }).when(checkpoint).batchWrite(any(), any()); + + worker.write(state, true, RequestPriority.MEDIUM); + // we don't retry checkpoint write + verify(checkpoint, times(1)).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testEmptyTimeStamp() { + ModelState state = mock(ModelState.class); + when(state.getLastCheckpointTime()).thenReturn(Instant.MIN); + worker.write(state, false, RequestPriority.MEDIUM); + + verify(checkpoint, never()).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testTooSoonToSaveSingleWrite() { + ModelState state = mock(ModelState.class); + when(state.getLastCheckpointTime()).thenReturn(Instant.now()); + worker.write(state, false, RequestPriority.MEDIUM); + + verify(checkpoint, never()).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testTooSoonToSaveWriteAll() { + ModelState state = mock(ModelState.class); + when(state.getLastCheckpointTime()).thenReturn(Instant.now()); + + List> states = new ArrayList<>(); + states.add(state); + + worker.writeAll(states, detectorId, false, RequestPriority.MEDIUM); + + verify(checkpoint, never()).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testEmptyModel() { + ModelState state = mock(ModelState.class); + when(state.getLastCheckpointTime()).thenReturn(Instant.now()); + when(state.getModel()).thenReturn(null); + worker.write(state, true, RequestPriority.MEDIUM); + + verify(checkpoint, never()).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testEmptyModelId() { + ModelState state = mock(ModelState.class); + when(state.getLastCheckpointTime()).thenReturn(Instant.now()); + EntityModel model = mock(EntityModel.class); + when(state.getModel()).thenReturn(model); + when(state.getId()).thenReturn("1"); + when(state.getModelId()).thenReturn(null); + 
worker.write(state, true, RequestPriority.MEDIUM); + + verify(checkpoint, never()).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testEmptyDetectorId() { + ModelState state = mock(ModelState.class); + when(state.getLastCheckpointTime()).thenReturn(Instant.now()); + EntityModel model = mock(EntityModel.class); + when(state.getModel()).thenReturn(model); + when(state.getId()).thenReturn(null); + when(state.getModelId()).thenReturn("a"); + worker.write(state, true, RequestPriority.MEDIUM); + + verify(checkpoint, never()).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testDetectorNotAvailableSingleWrite() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + worker.write(state, true, RequestPriority.MEDIUM); + verify(checkpoint, never()).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testDetectorNotAvailableWriteAll() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + List> states = new ArrayList<>(); + states.add(state); + worker.writeAll(states, detectorId, true, RequestPriority.MEDIUM); + verify(checkpoint, never()).batchWrite(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testDetectorFetchException() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException()); + return null; + }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + worker.write(state, true, RequestPriority.MEDIUM); + verify(checkpoint, never()).batchWrite(any(), any()); + } + + public void testCheckpointNullSource() throws IOException { + when(checkpoint.toIndexSource(any())).thenReturn(null); + worker.write(state, true, RequestPriority.MEDIUM); + verify(checkpoint, never()).batchWrite(any(), any()); + } + + public void testCheckpointEmptySource() throws IOException { + Map checkpointMap = new HashMap<>(); + when(checkpoint.toIndexSource(any())).thenReturn(checkpointMap); + worker.write(state, true, RequestPriority.MEDIUM); + verify(checkpoint, never()).batchWrite(any(), any()); + } + + public void testConcurrentModificationException() throws IOException { + doThrow(ConcurrentModificationException.class).when(checkpoint).toIndexSource(any()); + worker.write(state, true, RequestPriority.MEDIUM); + verify(checkpoint, never()).batchWrite(any(), any()); + } +} diff --git a/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java-e b/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java-e new file mode 100644 index 000000000..f4af298c8 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java-e @@ -0,0 +1,180 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Random; + +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; + +public class ColdEntityWorkerTests extends AbstractRateLimitingTest { + ClusterService clusterService; + ColdEntityWorker coldWorker; + CheckpointReadWorker readWorker; + EntityFeatureRequest request, request2, invalidRequest; + List requests; + + @Override + public void setUp() throws Exception { + super.setUp(); + clusterService = mock(ClusterService.class); + Settings settings = Settings.builder().put(AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE.getKey(), 1).build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, + AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE + ) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + readWorker = mock(CheckpointReadWorker.class); + + // Integer.MAX_VALUE makes a huge heap + coldWorker = new ColdEntityWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + settings, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + readWorker, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + nodeStateManager + ); + + request = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, entity, new double[] { 0 }, 0); + request2 = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, entity2, new double[] { 0 }, 0); + invalidRequest = new EntityFeatureRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity2, new double[] { 0 }, 0); + + requests = new ArrayList<>(); + requests.add(request); + requests.add(request2); + requests.add(invalidRequest); + + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + + TimeValue value = invocation.getArgument(1); + // since we have only 1 request each time + long expectedExecutionPerRequestMilli = AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS + .getDefault(Settings.EMPTY); + long delay = value.getMillis(); + assertTrue(delay == expectedExecutionPerRequestMilli); + return null; + }).when(threadPool).schedule(any(), any(), any()); + } + + public void testPutRequests() { + coldWorker.putAll(requests); + + 
verify(readWorker, times(2)).putAll(any()); + verify(threadPool, times(2)).schedule(any(), any(), any()); + } + + /** + * We will log a line and continue trying despite exception + */ + public void testCheckpointReadPutException() { + doThrow(RuntimeException.class).when(readWorker).putAll(any()); + coldWorker.putAll(requests); + verify(readWorker, times(2)).putAll(any()); + verify(threadPool, never()).schedule(any(), any(), any()); + } + + /** + * First, invalidRequest gets pulled out and we re-pull; Then we have schedule exception. + * Will not schedule others anymore. + */ + public void testScheduleException() { + doThrow(RuntimeException.class).when(threadPool).schedule(any(), any(), any()); + coldWorker.putAll(requests); + verify(readWorker, times(1)).putAll(any()); + verify(threadPool, times(1)).schedule(any(), any(), any()); + } + + public void testDelay() { + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, + AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE + ) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + // Integer.MAX_VALUE makes a huge heap + coldWorker = new ColdEntityWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + readWorker, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + nodeStateManager + ); + + coldWorker.putAll(requests); + + verify(readWorker, times(1)).putAll(any()); + verify(threadPool, never()).schedule(any(), any(), any()); + } +} diff --git a/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java index cd2a1ac81..5580b5f30 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java @@ -39,7 +39,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; -import org.opensearch.rest.RestStatus; +import org.opensearch.core.rest.RestStatus; import test.org.opensearch.ad.util.MLUtil; diff --git a/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java-e b/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java-e new file mode 100644 index 000000000..5580b5f30 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java-e @@ -0,0 +1,165 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Optional; +import java.util.Random; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.ml.EntityColdStarter; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; +import org.opensearch.core.rest.RestStatus; + +import test.org.opensearch.ad.util.MLUtil; + +public class EntityColdStartWorkerTests extends AbstractRateLimitingTest { + ClusterService clusterService; + EntityColdStartWorker worker; + EntityColdStarter entityColdStarter; + CacheProvider cacheProvider; + + @Override + public void setUp() throws Exception { + super.setUp(); + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_CONCURRENCY + ) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + entityColdStarter = mock(EntityColdStarter.class); + + cacheProvider = mock(CacheProvider.class); + + // Integer.MAX_VALUE makes a huge heap + worker = new EntityColdStartWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.ENTITY_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + entityColdStarter, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + nodeStateManager, + cacheProvider + ); + } + + public void testEmptyModelId() { + EntityRequest request = mock(EntityRequest.class); + when(request.getPriority()).thenReturn(RequestPriority.LOW); + when(request.getModelId()).thenReturn(Optional.empty()); + worker.put(request); + verify(entityColdStarter, never()).trainModel(any(), anyString(), any(), any()); + verify(request, times(1)).getModelId(); + } + + public void testOverloaded() { + EntityRequest request = new EntityRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new OpenSearchRejectedExecutionException("blah", true)); + + return null; + 
}).when(entityColdStarter).trainModel(any(), anyString(), any(), any()); + + worker.put(request); + + verify(entityColdStarter, times(1)).trainModel(any(), anyString(), any(), any()); + verify(nodeStateManager, times(1)).setException(eq(detectorId), any(OpenSearchRejectedExecutionException.class)); + + // 2nd put request won't trigger anything as we are in cooldown mode + worker.put(request); + verify(entityColdStarter, times(1)).trainModel(any(), anyString(), any(), any()); + } + + public void testException() { + EntityRequest request = new EntityRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new OpenSearchStatusException("blah", RestStatus.REQUEST_TIMEOUT)); + + return null; + }).when(entityColdStarter).trainModel(any(), anyString(), any(), any()); + + worker.put(request); + + verify(entityColdStarter, times(1)).trainModel(any(), anyString(), any(), any()); + verify(nodeStateManager, times(1)).setException(eq(detectorId), any(OpenSearchStatusException.class)); + + // 2nd put request triggers another setException + worker.put(request); + verify(entityColdStarter, times(2)).trainModel(any(), anyString(), any(), any()); + verify(nodeStateManager, times(2)).setException(eq(detectorId), any(OpenSearchStatusException.class)); + } + + public void testModelHosted() { + EntityRequest request = new EntityRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + + ModelState state = invocation.getArgument(2); + state.setModel(MLUtil.createNonEmptyModel(detectorId)); + listener.onResponse(null); + + return null; + }).when(entityColdStarter).trainModel(any(), anyString(), any(), any()); + + worker.put(request); + + verify(cacheProvider, times(1)).get(); + } +} diff --git a/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java index 666bc03ed..4b46311c6 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java @@ -45,8 +45,8 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.rest.RestStatus; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.util.RestHandlerUtils; diff --git a/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java-e b/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java-e new file mode 100644 index 000000000..d8a31ae40 --- /dev/null +++ b/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java-e @@ -0,0 +1,208 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.transport.ADResultBulkRequest; +import org.opensearch.ad.transport.ADResultBulkResponse; +import org.opensearch.ad.transport.handler.MultiEntityResultHandler; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.util.RestHandlerUtils; + +public class ResultWriteWorkerTests extends AbstractRateLimitingTest { + ResultWriteWorker resultWriteQueue; + ClusterService clusterService; + MultiEntityResultHandler resultHandler; + AnomalyResult detectResult; + + @Override + public void setUp() throws Exception { + super.setUp(); + + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_BATCH_SIZE + ) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + threadPool = mock(ThreadPool.class); + setUpADThreadPool(threadPool); + + resultHandler = mock(MultiEntityResultHandler.class); + + resultWriteQueue = new ResultWriteWorker( + Integer.MAX_VALUE, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_SIZE_IN_BYTES, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + clusterService, + new Random(42), + mock(ADCircuitBreakerService.class), + threadPool, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + clock, + AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, + AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + AnomalyDetectorSettings.QUEUE_MAINTENANCE, + resultHandler, + xContentRegistry(), + nodeStateManager, + AnomalyDetectorSettings.HOURLY_MAINTENANCE + ); + + detectResult = TestHelpers.randomHCADAnomalyDetectResult(0.8, Double.NaN, null); + } + + public void testRegular() { + List retryRequests = new ArrayList<>(); + + ADResultBulkResponse resp = new 
ADResultBulkResponse(retryRequests);
+
+ ADResultBulkRequest request = new ADResultBulkRequest();
+ ResultWriteRequest resultWriteRequest = new ResultWriteRequest(
+ Instant.now().plus(10, ChronoUnit.MINUTES).toEpochMilli(),
+ detectorId,
+ RequestPriority.MEDIUM,
+ detectResult,
+ null
+ );
+ request.add(resultWriteRequest);
+
+ doAnswer(invocation -> {
+ ActionListener listener = invocation.getArgument(1);
+ listener.onResponse(resp);
+ return null;
+ }).when(resultHandler).flush(any(), any());
+
+ resultWriteQueue.put(new ResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null));
+
+ // the request results in one flush
+ verify(resultHandler, times(1)).flush(any(), any());
+ }
+
+ public void testSingleRetryRequest() throws IOException {
+ List retryRequests = new ArrayList<>();
+ try (XContentBuilder builder = jsonBuilder()) {
+ IndexRequest indexRequest = new IndexRequest(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)
+ .source(detectResult.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE));
+ retryRequests.add(indexRequest);
+ }
+
+ ADResultBulkResponse resp = new ADResultBulkResponse(retryRequests);
+
+ ADResultBulkRequest request = new ADResultBulkRequest();
+ ResultWriteRequest resultWriteRequest = new ResultWriteRequest(
+ Instant.now().plus(10, ChronoUnit.MINUTES).toEpochMilli(),
+ detectorId,
+ RequestPriority.MEDIUM,
+ detectResult,
+ null
+ );
+ request.add(resultWriteRequest);
+
+ final AtomicBoolean retried = new AtomicBoolean();
+ doAnswer(invocation -> {
+ ActionListener listener = invocation.getArgument(1);
+ if (retried.get()) {
+ listener.onResponse(new ADResultBulkResponse());
+ } else {
+ retried.set(true);
+ listener.onResponse(resp);
+ }
+ return null;
+ }).when(resultHandler).flush(any(), any());
+
+ resultWriteQueue.put(new ResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null));
+
+ // one flush from the original request; and one due to retry
+ verify(resultHandler, times(2)).flush(any(), any());
+ }
+
+ public void testRetryException() {
+ final AtomicBoolean retried = new AtomicBoolean();
+ doAnswer(invocation -> {
+ ActionListener listener = invocation.getArgument(1);
+ if (retried.get()) {
+ listener.onResponse(new ADResultBulkResponse());
+ } else {
+ retried.set(true);
+ listener.onFailure(new OpenSearchStatusException("blah", RestStatus.REQUEST_TIMEOUT));
+ }
+
+ return null;
+ }).when(resultHandler).flush(any(), any());
+
+ resultWriteQueue.put(new ResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null));
+ // one flush from the original request; and one due to retry
+ verify(resultHandler, times(2)).flush(any(), any());
+ verify(nodeStateManager, times(1)).setException(eq(detectorId), any(OpenSearchStatusException.class));
+ }
+
+ public void testOverloaded() {
+ doAnswer(invocation -> {
+ ActionListener listener = invocation.getArgument(1);
+ listener.onFailure(new OpenSearchRejectedExecutionException("blah", true));
+
+ return null;
+ }).when(resultHandler).flush(any(), any());
+
+ resultWriteQueue.put(new ResultWriteRequest(Long.MAX_VALUE, detectorId, RequestPriority.MEDIUM, detectResult, null));
+ // only the original flush; we don't retry when the handler is overloaded
+ verify(resultHandler, times(1)).flush(any(), any());
+ verify(nodeStateManager, times(1)).setException(eq(detectorId), any(OpenSearchRejectedExecutionException.class));
+ }
+}
diff --git a/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java-e 
b/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java-e new file mode 100644 index 000000000..fb1ccc1e4 --- /dev/null +++ b/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java-e @@ -0,0 +1,554 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.ad.rest; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.randomBoolean; +import static org.opensearch.test.OpenSearchTestCase.randomAlphaOfLength; +import static org.opensearch.test.OpenSearchTestCase.randomDoubleBetween; +import static org.opensearch.test.OpenSearchTestCase.randomInt; +import static org.opensearch.test.OpenSearchTestCase.randomIntBetween; +import static org.opensearch.test.OpenSearchTestCase.randomLong; +import static org.opensearch.test.rest.OpenSearchRestTestCase.entityAsMap; +import static org.opensearch.timeseries.util.RestHandlerUtils.ANOMALY_DETECTOR_JOB; +import static org.opensearch.timeseries.util.RestHandlerUtils.HISTORICAL_ANALYSIS_TASK; +import static org.opensearch.timeseries.util.RestHandlerUtils.REALTIME_TASK; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.ToDoubleFunction; + +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.ParseException; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.core.Logger; +import org.opensearch.ad.mock.model.MockSimpleLog; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskProfile; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +//TODO: remove duplicate code in HistoricalAnalysisRestTestCase +public class ADRestTestUtils { + protected static final Logger LOG = (Logger) LogManager.getLogger(ADRestTestUtils.class); + + public enum DetectorType { + SINGLE_ENTITY_DETECTOR, + SINGLE_CATEGORY_HC_DETECTOR, + MULTI_CATEGORY_HC_DETECTOR + } + + public static Response ingestSimpleMockLog( + RestClient client, + String indexName, + int startDays, + int totalDocsPerCategory, + long intervalInMinutes, + ToDoubleFunction valueFunc, + int ipSize, + int categorySize, + boolean createIndex + ) throws IOException { + if (createIndex) { + TestHelpers + .makeRequest( + client, + "PUT", + indexName, + null, + TestHelpers.toHttpEntity(MockSimpleLog.INDEX_MAPPING), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "OpenSearch")) + ); + } + + StringBuilder bulkRequestBuilder = new StringBuilder(); + Instant startTime = Instant.now().minus(startDays, ChronoUnit.DAYS); + for (int i = 0; i < totalDocsPerCategory; i++) { + for (int m = 0; m < ipSize; m++) { + String ip = "192.168.1." 
+ m; + for (int n = 0; n < categorySize; n++) { + String category = "category" + n; + String docId = randomAlphaOfLength(10); + bulkRequestBuilder.append("{ \"index\" : { \"_index\" : \"" + indexName + "\", \"_id\" : \"" + docId + "\" } }\n"); + MockSimpleLog simpleLog1 = new MockSimpleLog( + startTime, + valueFunc.applyAsDouble(i), + ip, + category, + randomBoolean(), + randomAlphaOfLength(5) + ); + bulkRequestBuilder.append(TestHelpers.toJsonString(simpleLog1)); + bulkRequestBuilder.append("\n"); + } + } + startTime = startTime.plus(intervalInMinutes, ChronoUnit.MINUTES); + } + Response bulkResponse = TestHelpers + .makeRequest( + client, + "POST", + "_bulk?refresh=true", + null, + TestHelpers.toHttpEntity(bulkRequestBuilder.toString()), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + return bulkResponse; + } + + public static Response ingestTestDataForHistoricalAnalysis( + RestClient client, + String indexName, + int detectionIntervalInMinutes, + boolean createIndex, + int startDays, + int totalDocsPerCategory, + int categoryFieldSize + ) throws IOException { + return ingestSimpleMockLog(client, indexName, startDays, totalDocsPerCategory, detectionIntervalInMinutes, (i) -> { + if (i % 500 == 0) { + return randomDoubleBetween(100, 1000, true); + } else { + return randomDoubleBetween(1, 10, true); + } + }, categoryFieldSize, categoryFieldSize, createIndex); + } + + @SuppressWarnings("unchecked") + public static int getDocCountOfIndex(RestClient client, String indexName) throws IOException { + Response searchResponse = TestHelpers + .makeRequest( + client, + "GET", + indexName + "/_search", + null, + TestHelpers.toHttpEntity("{\"track_total_hits\": true}"), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "OpenSearch")) + ); + + Map responseMap = entityAsMap(searchResponse); + Object total = ((Map) responseMap.get("hits")).get("total"); + return (int) ((Map) total).get("value"); + } + + public static Response createAnomalyDetector( + RestClient client, + String indexName, + String timeField, + int detectionIntervalInMinutes, + int windowDelayIntervalInMinutes, + String valueField, + String aggregationMethod, + String filterQuery, + List categoryFields + ) throws Exception { + return createAnomalyDetector( + client, + indexName, + timeField, + detectionIntervalInMinutes, + windowDelayIntervalInMinutes, + valueField, + aggregationMethod, + filterQuery, + categoryFields, + false + ); + } + + public static Response createAnomalyDetector( + RestClient client, + String indexName, + String timeField, + int detectionIntervalInMinutes, + int windowDelayIntervalInMinutes, + String valueField, + String aggregationMethod, + String filterQuery, + List categoryFields, + boolean historical + ) throws Exception { + Instant now = Instant.now(); + AnomalyDetector detector = new AnomalyDetector( + randomAlphaOfLength(10), + randomLong(), + // TODO: check why throw duplicate detector name error with randomAlphaOfLength(20) in twoThirdsUpgradedClusterTask + randomAlphaOfLength(20) + now.toEpochMilli(), + randomAlphaOfLength(30), + timeField, + ImmutableList.of(indexName), + ImmutableList.of(TestHelpers.randomFeature(randomAlphaOfLength(5), valueField, aggregationMethod, true)), + filterQuery == null ? 
TestHelpers.randomQuery("{\"match_all\":{\"boost\":1}}") : TestHelpers.randomQuery(filterQuery), + new IntervalTimeConfiguration(detectionIntervalInMinutes, ChronoUnit.MINUTES), + new IntervalTimeConfiguration(windowDelayIntervalInMinutes, ChronoUnit.MINUTES), + randomIntBetween(1, 20), + null, + randomInt(), + now, + categoryFields, + TestHelpers.randomUser(), + null, + TestHelpers.randomImputationOption() + ); + + if (historical) { + detector.setDetectionDateRange(new DateRange(now.minus(30, ChronoUnit.DAYS), now)); + } + + return TestHelpers + .makeRequest( + client, + "POST", + TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI, + ImmutableMap.of(), + TestHelpers.toHttpEntity(detector), + null + ); + } + + @SuppressWarnings("unchecked") + public static List searchLatestAdTaskOfDetector(RestClient client, String detectorId, String taskType) throws IOException { + List adTasks = new ArrayList<>(); + Response searchAdTaskResponse = TestHelpers + .makeRequest( + client, + "POST", + TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI + "/tasks/_search", + ImmutableMap.of(), + TestHelpers + .toHttpEntity( + "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"detector_id\":\"" + + detectorId + + "\"}},{\"term\":{\"is_latest\":\"true\"}},{\"terms\":{\"task_type\":[\"" + + taskType + + "\"]}}]}},\"sort\":[{\"execution_start_time\":{\"order\":\"desc\"}}],\"size\":1000}" + ), + null + ); + Map responseMap = entityAsMap(searchAdTaskResponse); + Map hits = (Map) responseMap.get("hits"); + Object totalHits = hits.get("total"); + Integer totalTasks = (Integer) ((Map) totalHits).get("value"); + + if (totalTasks == 0) { + return adTasks; + } + List adTaskResponses = (List) hits.get("hits"); + for (Object adTaskResponse : adTaskResponses) { + String id = (String) ((Map) adTaskResponse).get("_id"); + Map source = (Map) ((Map) adTaskResponse).get("_source"); + String state = (String) source.get(ADTask.STATE_FIELD); + String parsedDetectorId = (String) source.get(ADTask.DETECTOR_ID_FIELD); + Double taskProgress = (Double) source.get(ADTask.TASK_PROGRESS_FIELD); + Double initProgress = (Double) source.get(ADTask.INIT_PROGRESS_FIELD); + String parsedTaskType = (String) source.get(ADTask.TASK_TYPE_FIELD); + String coordinatingNode = (String) source.get(ADTask.COORDINATING_NODE_FIELD); + ADTask adTask = ADTask + .builder() + .taskId(id) + .state(state) + .detectorId(parsedDetectorId) + .taskProgress(taskProgress.floatValue()) + .initProgress(initProgress.floatValue()) + .taskType(parsedTaskType) + .coordinatingNode(coordinatingNode) + .build(); + adTasks.add(adTask); + } + return adTasks; + } + + @SuppressWarnings("unchecked") + public static int countADResultOfDetector(RestClient client, String detectorId, String taskId) throws IOException { + String taskFilter = "TASK_FILTER"; + String query = "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"detector_id\":\"" + + detectorId + + "\"}}" + + taskFilter + + "]}},\"track_total_hits\":true,\"size\":0}"; + if (taskId != null) { + query = query.replace(taskFilter, ",{\"term\":{\"task_id\":\"" + taskId + "\"}}"); + } else { + query = query.replace(taskFilter, ""); + } + Response searchAdTaskResponse = TestHelpers + .makeRequest( + client, + "GET", + TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI + "/results/_search", + ImmutableMap.of(), + TestHelpers.toHttpEntity(query), + + null + ); + Map responseMap = entityAsMap(searchAdTaskResponse); + Map hits = (Map) ((Map) responseMap.get("hits")).get("total"); + return (int) hits.get("value"); + } + + 
@SuppressWarnings("unchecked") + public static int countDetectors(RestClient client, String detectorType) throws IOException { + String detectorTypeFilter = "DETECTOR_TYPE_FILTER"; + String query = "{\"query\":{\"bool\":{\"filter\":[{\"exists\":{\"field\":\"name\"}}" + + detectorTypeFilter + + "]}},\"track_total_hits\":true,\"size\":0}"; + if (detectorType != null) { + query = query.replace(detectorTypeFilter, ",{\"term\":{\"detector_type\":\"" + detectorType + "\"}}"); + } else { + query = query.replace(detectorTypeFilter, ""); + } + Response searchAdTaskResponse = TestHelpers + .makeRequest( + client, + "GET", + TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI + "/_search", + ImmutableMap.of(), + TestHelpers.toHttpEntity(query), + + null + ); + Map responseMap = entityAsMap(searchAdTaskResponse); + Map hits = (Map) ((Map) responseMap.get("hits")).get("total"); + return (int) hits.get("value"); + } + + @SuppressWarnings("unchecked") + public static Map getDetectorWithJobAndTask(RestClient client, String detectorId) throws IOException { + Map results = new HashMap<>(); + Response searchAdTaskResponse = TestHelpers + .makeRequest( + client, + "GET", + TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI + "/" + detectorId + "?job=true&task=true", + ImmutableMap.of(), + (HttpEntity) null, + null + ); + Map responseMap = entityAsMap(searchAdTaskResponse); + + Map jobMap = (Map) responseMap.get(ANOMALY_DETECTOR_JOB); + if (jobMap != null) { + String jobName = (String) jobMap.get(AnomalyDetectorJob.NAME_FIELD); + boolean enabled = (boolean) jobMap.get(AnomalyDetectorJob.IS_ENABLED_FIELD); + long enabledTime = (long) jobMap.get(AnomalyDetectorJob.ENABLED_TIME_FIELD); + long lastUpdateTime = (long) jobMap.get(AnomalyDetectorJob.LAST_UPDATE_TIME_FIELD); + + AnomalyDetectorJob job = new AnomalyDetectorJob( + jobName, + null, + null, + enabled, + Instant.ofEpochMilli(enabledTime), + null, + Instant.ofEpochMilli(lastUpdateTime), + null, + null, + null + ); + results.put(ANOMALY_DETECTOR_JOB, job); + } + + Map historicalTaskMap = (Map) responseMap.get(HISTORICAL_ANALYSIS_TASK); + if (historicalTaskMap != null) { + ADTask historicalAdTask = parseAdTask(historicalTaskMap); + results.put(HISTORICAL_ANALYSIS_TASK, historicalAdTask); + } + + Map realtimeTaskMap = (Map) responseMap.get(REALTIME_TASK); + if (realtimeTaskMap != null) { + ADTask realtimeAdTask = parseAdTask(realtimeTaskMap); + results.put(REALTIME_TASK, realtimeAdTask); + } + + return results; + } + + private static ADTask parseAdTask(Map taskMap) { + String id = (String) taskMap.get(ADTask.TASK_ID_FIELD); + String state = (String) taskMap.get(ADTask.STATE_FIELD); + String parsedDetectorId = (String) taskMap.get(ADTask.DETECTOR_ID_FIELD); + Double taskProgress = (Double) taskMap.get(ADTask.TASK_PROGRESS_FIELD); + Double initProgress = (Double) taskMap.get(ADTask.INIT_PROGRESS_FIELD); + String parsedTaskType = (String) taskMap.get(ADTask.TASK_TYPE_FIELD); + String coordinatingNode = (String) taskMap.get(ADTask.COORDINATING_NODE_FIELD); + return ADTask + .builder() + .taskId(id) + .state(state) + .detectorId(parsedDetectorId) + .taskProgress(taskProgress.floatValue()) + .initProgress(initProgress.floatValue()) + .taskType(parsedTaskType) + .coordinatingNode(coordinatingNode) + .build(); + } + + /** + * Start anomaly detector directly. + * For AD versions on or before 1.0, this function will start realtime job for + * realtime detector, and start historical analysis for historical detector. 
+ * + * For AD version on or after 1.1, this function will start realtime job only. + * @param client REST client + * @param detectorId detector id + * @return job id for realtime job or task id for historical analysis + * @throws IOException exception may throw in entityAsMap + */ + @SuppressWarnings("unchecked") + public static String startAnomalyDetectorDirectly(RestClient client, String detectorId) throws IOException { + Response response = TestHelpers + .makeRequest( + client, + "POST", + TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI + "/" + detectorId + "/_start", + ImmutableMap.of(), + (HttpEntity) null, + null + ); + Map startDetectorResponseMap = entityAsMap(response); + // For AD on or before 1.0, if the detector is historical detector, then it will be task id + String jobOrTaskId = (String) startDetectorResponseMap.get("_id"); + return jobOrTaskId; + } + + /** + * Start historical analysis. + * For AD versions on or before 1.0, should pass historical detector id to + * this function. + * For AD version on or after 1.1, can pass any detector id to this function. + * + * @param client REST client + * @param detectorId detector id + * @return task id of historical analysis + * @throws IOException exception may throw in toHttpEntity and entityAsMap + */ + @SuppressWarnings("unchecked") + public static String startHistoricalAnalysis(RestClient client, String detectorId) throws IOException { + Instant now = Instant.now(); + DateRange dateRange = new DateRange(now.minus(30, ChronoUnit.DAYS), now); + Response response = TestHelpers + .makeRequest( + client, + "POST", + TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI + "/" + detectorId + "/_start", + ImmutableMap.of(), + // Start historical detector directly on new node will start realtime job. + // Need to pass detection date range in http body if need to start historical analysis. 
+ TestHelpers.toHttpEntity(TestHelpers.toJsonString(dateRange)), + null + ); + Map startDetectorResponseMap = entityAsMap(response); + String taskId = (String) startDetectorResponseMap.get("_id"); + return taskId; + } + + public static ADTaskProfile waitUntilTaskDone(RestClient client, String detectorId) throws InterruptedException { + return waitUntilTaskReachState(client, detectorId, TestHelpers.HISTORICAL_ANALYSIS_DONE_STATS); + } + + public static ADTaskProfile waitUntilTaskReachState(RestClient client, String detectorId, Set targetStates) + throws InterruptedException { + int i = 0; + int retryTimes = 200; + ADTaskProfile adTaskProfile = null; + while ((adTaskProfile == null || !targetStates.contains(adTaskProfile.getAdTask().getState())) && i < retryTimes) { + try { + adTaskProfile = getADTaskProfile(client, detectorId); + } catch (Exception e) { + LOG.error("failed to get ADTaskProfile", e); + } finally { + Thread.sleep(1000); + } + i++; + } + // assertNotNull(adTaskProfile); + return adTaskProfile; + } + + public static ADTaskProfile getADTaskProfile(RestClient client, String detectorId) throws IOException, ParseException { + Response profileResponse = TestHelpers + .makeRequest( + client, + "GET", + TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI + "/" + detectorId + "/_profile?_all", + ImmutableMap.of(), + "", + null + ); + return parseADTaskProfile(profileResponse); + } + + public static ADTaskProfile parseADTaskProfile(Response profileResponse) throws IOException, ParseException { + String profileResult = EntityUtils.toString(profileResponse.getEntity()); + XContentParser parser = TestHelpers.parser(profileResult); + ADTaskProfile adTaskProfile = null; + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + if ("ad_task".equals(fieldName)) { + adTaskProfile = ADTaskProfile.parse(parser); + } else { + parser.skipChildren(); + } + } + return adTaskProfile; + } + + public static Response stopRealtimeJob(RestClient client, String detectorId) throws IOException { + return stopDetector(client, detectorId, false); + } + + public static Response stopHistoricalAnalysis(RestClient client, String detectorId) throws IOException { + return stopDetector(client, detectorId, true); + } + + public static Response stopDetector(RestClient client, String detectorId, boolean historicalAnalysis) throws IOException { + String param = historicalAnalysis ? 
"?historical" : ""; + Response response = TestHelpers + .makeRequest( + client, + "POST", + TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI + "/" + detectorId + "/_stop" + param, + ImmutableMap.of(), + "", + null + ); + return response; + } + + public static Response deleteDetector(RestClient client, String detectorId) throws IOException { + Response response = TestHelpers + .makeRequest( + client, + "DELETE", + TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI + "/" + detectorId, + ImmutableMap.of(), + "", + null + ); + return response; + } +} diff --git a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java index f520439a3..390e68ef7 100644 --- a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java +++ b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java @@ -31,7 +31,6 @@ import org.apache.hc.core5.http.io.entity.StringEntity; import org.hamcrest.CoreMatchers; import org.junit.Assert; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.AnomalyDetectorRestTestCase; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; @@ -45,11 +44,12 @@ import org.opensearch.client.ResponseException; import org.opensearch.common.UUIDs; import org.opensearch.common.xcontent.support.XContentMapValues; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.index.query.QueryBuilders; -import org.opensearch.rest.RestStatus; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.DateRange; @@ -481,14 +481,14 @@ public void testStatsAnomalyDetector() throws Exception { updateClusterSettings(ADEnabledSetting.AD_ENABLED, false); Exception ex = expectThrows( ResponseException.class, - () -> TestHelpers.makeRequest(client(), "GET", AnomalyDetectorPlugin.LEGACY_AD_BASE + "/stats", ImmutableMap.of(), "", null) + () -> TestHelpers.makeRequest(client(), "GET", TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE + "/stats", ImmutableMap.of(), "", null) ); assertThat(ex.getMessage(), containsString(ADCommonMessages.DISABLED_ERR_MSG)); updateClusterSettings(ADEnabledSetting.AD_ENABLED, true); Response statsResponse = TestHelpers - .makeRequest(client(), "GET", AnomalyDetectorPlugin.LEGACY_AD_BASE + "/stats", ImmutableMap.of(), "", null); + .makeRequest(client(), "GET", TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE + "/stats", ImmutableMap.of(), "", null); assertEquals("Get stats failed", RestStatus.OK, TestHelpers.restStatus(statsResponse)); } diff --git a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java-e b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java-e new file mode 100644 index 000000000..e71ff885f --- /dev/null +++ b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java-e @@ -0,0 +1,1858 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.rest; + +import static org.hamcrest.Matchers.containsString; +import static org.opensearch.ad.rest.handler.AbstractAnomalyDetectorActionHandler.DUPLICATE_DETECTOR_MSG; +import static org.opensearch.ad.rest.handler.AbstractAnomalyDetectorActionHandler.NO_DOCS_IN_USER_INDEX_MSG; +import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.io.entity.StringEntity; +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.opensearch.ad.AnomalyDetectorRestTestCase; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorExecutionInput; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.rest.handler.AbstractAnomalyDetectorActionHandler; +import org.opensearch.ad.settings.ADEnabledSetting; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.common.UUIDs; +import org.opensearch.common.xcontent.support.XContentMapValues; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class AnomalyDetectorRestApiIT extends AnomalyDetectorRestTestCase { + + protected static final String INDEX_NAME = "indexname"; + protected static final String TIME_FIELD = "timestamp"; + + public void testCreateAnomalyDetectorWithNotExistingIndices() throws Exception { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null); + TestHelpers + .assertFailWith( + ResponseException.class, + "index_not_found_exception", + () -> TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI, + ImmutableMap.of(), + TestHelpers.toHttpEntity(detector), + null + ) + ); + } + + public void testCreateAnomalyDetectorWithEmptyIndices() throws Exception { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), null); + TestHelpers + .makeRequest( + client(), + "PUT", + "/" + detector.getIndices().get(0), + ImmutableMap.of(), + TestHelpers + .toHttpEntity( + "{\"settings\":{\"number_of_shards\":1}," + " \"mappings\":{\"properties\":" + "{\"field1\":{\"type\":\"text\"}}}}" + ), + null + ); + + TestHelpers + .assertFailWith( + ResponseException.class, + "Can't create anomaly detector as no document is found in the indices", + () -> TestHelpers + .makeRequest( + client(), + "POST", + 
TestHelpers.AD_BASE_DETECTORS_URI, + ImmutableMap.of(), + TestHelpers.toHttpEntity(detector), + null + ) + ); + } + + private AnomalyDetector createIndexAndGetAnomalyDetector(String indexName) throws IOException { + return createIndexAndGetAnomalyDetector(indexName, ImmutableList.of(TestHelpers.randomFeature(true))); + } + + private AnomalyDetector createIndexAndGetAnomalyDetector(String indexName, List features) throws IOException { + TestHelpers.createIndexWithTimeField(client(), indexName, TIME_FIELD); + String testIndexData = "{\"keyword-field\": \"field-1\", \"ip-field\": \"1.2.3.4\", \"timestamp\": 1}"; + TestHelpers.ingestDataToIndex(client(), indexName, TestHelpers.toHttpEntity(testIndexData)); + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(TIME_FIELD, indexName, features); + return detector; + } + + public void testCreateAnomalyDetectorWithDuplicateName() throws Exception { + AnomalyDetector detector = createIndexAndGetAnomalyDetector(INDEX_NAME); + AnomalyDetector detectorDuplicateName = new AnomalyDetector( + AnomalyDetector.NO_ID, + randomLong(), + detector.getName(), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + detector.getIndices(), + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomQuery(), + TestHelpers.randomIntervalTimeConfiguration(), + TestHelpers.randomIntervalTimeConfiguration(), + randomIntBetween(1, TimeSeriesSettings.MAX_SHINGLE_SIZE), + TestHelpers.randomUiMetadata(), + randomInt(), + null, + null, + TestHelpers.randomUser(), + null, + TestHelpers.randomImputationOption() + ); + + TestHelpers + .assertFailWith( + ResponseException.class, + "Cannot create anomaly detector with name", + () -> TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI, + ImmutableMap.of(), + TestHelpers.toHttpEntity(detectorDuplicateName), + null + ) + ); + } + + public void testCreateAnomalyDetector() throws Exception { + AnomalyDetector detector = createIndexAndGetAnomalyDetector(INDEX_NAME); + updateClusterSettings(ADEnabledSetting.AD_ENABLED, false); + + Exception ex = expectThrows( + ResponseException.class, + () -> TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI, + ImmutableMap.of(), + TestHelpers.toHttpEntity(detector), + null + ) + ); + assertThat(ex.getMessage(), containsString(ADCommonMessages.DISABLED_ERR_MSG)); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, true); + Response response = TestHelpers + .makeRequest(client(), "POST", TestHelpers.AD_BASE_DETECTORS_URI, ImmutableMap.of(), TestHelpers.toHttpEntity(detector), null); + assertEquals("Create anomaly detector failed", RestStatus.CREATED, TestHelpers.restStatus(response)); + Map responseMap = entityAsMap(response); + String id = (String) responseMap.get("_id"); + int version = (int) responseMap.get("_version"); + assertNotEquals("response is missing Id", AnomalyDetector.NO_ID, id); + assertTrue("incorrect version", version > 0); + } + + public void testUpdateAnomalyDetectorCategoryField() throws Exception { + AnomalyDetector detector = createIndexAndGetAnomalyDetector(INDEX_NAME); + Response response = TestHelpers + .makeRequest(client(), "POST", TestHelpers.AD_BASE_DETECTORS_URI, ImmutableMap.of(), TestHelpers.toHttpEntity(detector), null); + assertEquals("Create anomaly detector failed", RestStatus.CREATED, TestHelpers.restStatus(response)); + Map responseMap = entityAsMap(response); + String id = (String) responseMap.get("_id"); + AnomalyDetector newDetector = new AnomalyDetector( + id, + null, + 
detector.getName(), + detector.getDescription(), + detector.getTimeField(), + detector.getIndices(), + detector.getFeatureAttributes(), + detector.getFilterQuery(), + detector.getInterval(), + detector.getWindowDelay(), + detector.getShingleSize(), + detector.getUiMetadata(), + detector.getSchemaVersion(), + detector.getLastUpdateTime(), + ImmutableList.of(randomAlphaOfLength(5)), + detector.getUser(), + null, + TestHelpers.randomImputationOption() + ); + Exception ex = expectThrows( + ResponseException.class, + () -> TestHelpers + .makeRequest( + client(), + "PUT", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + id + "?refresh=true", + ImmutableMap.of(), + TestHelpers.toHttpEntity(newDetector), + null + ) + ); + assertThat(ex.getMessage(), containsString(CommonMessages.CAN_NOT_CHANGE_CATEGORY_FIELD)); + } + + public void testGetAnomalyDetector() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, true, client()); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, false); + + Exception ex = expectThrows(ResponseException.class, () -> getAnomalyDetector(detector.getId(), client())); + assertThat(ex.getMessage(), containsString(ADCommonMessages.DISABLED_ERR_MSG)); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, true); + + AnomalyDetector createdDetector = getAnomalyDetector(detector.getId(), client()); + assertEquals("Incorrect Location header", detector, createdDetector); + } + + public void testGetNotExistingAnomalyDetector() throws Exception { + createRandomAnomalyDetector(true, true, client()); + TestHelpers.assertFailWith(ResponseException.class, null, () -> getAnomalyDetector(randomAlphaOfLength(5), client())); + } + + public void testUpdateAnomalyDetector() throws Exception { + AnomalyDetector detector = createAnomalyDetector(createIndexAndGetAnomalyDetector(INDEX_NAME), true, client()); + String newDescription = randomAlphaOfLength(5); + AnomalyDetector newDetector = new AnomalyDetector( + detector.getId(), + detector.getVersion(), + detector.getName(), + newDescription, + detector.getTimeField(), + detector.getIndices(), + detector.getFeatureAttributes(), + detector.getFilterQuery(), + detector.getInterval(), + detector.getWindowDelay(), + detector.getShingleSize(), + detector.getUiMetadata(), + detector.getSchemaVersion(), + detector.getLastUpdateTime(), + null, + detector.getUser(), + null, + TestHelpers.randomImputationOption() + ); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, false); + + Exception ex = expectThrows( + ResponseException.class, + () -> TestHelpers + .makeRequest( + client(), + "PUT", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "?refresh=true", + ImmutableMap.of(), + TestHelpers.toHttpEntity(newDetector), + null + ) + ); + assertThat(ex.getMessage(), containsString(ADCommonMessages.DISABLED_ERR_MSG)); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, true); + + Response updateResponse = TestHelpers + .makeRequest( + client(), + "PUT", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "?refresh=true", + ImmutableMap.of(), + TestHelpers.toHttpEntity(newDetector), + null + ); + + assertEquals("Update anomaly detector failed", RestStatus.OK, TestHelpers.restStatus(updateResponse)); + Map responseBody = entityAsMap(updateResponse); + assertEquals("Updated anomaly detector id doesn't match", detector.getId(), responseBody.get("_id")); + assertEquals("Version not incremented", (detector.getVersion().intValue() + 1), (int) responseBody.get("_version")); + + AnomalyDetector updatedDetector = 
getAnomalyDetector(detector.getId(), client()); + assertNotEquals("Anomaly detector last update time not changed", updatedDetector.getLastUpdateTime(), detector.getLastUpdateTime()); + assertEquals("Anomaly detector description not updated", newDescription, updatedDetector.getDescription()); + } + + public void testUpdateAnomalyDetectorNameToExisting() throws Exception { + AnomalyDetector detector1 = createIndexAndGetAnomalyDetector("index-test-one"); + AnomalyDetector detector2 = createIndexAndGetAnomalyDetector("index-test-two"); + AnomalyDetector newDetector1WithDetector2Name = new AnomalyDetector( + detector1.getId(), + detector1.getVersion(), + detector2.getName(), + detector1.getDescription(), + detector1.getTimeField(), + detector1.getIndices(), + detector1.getFeatureAttributes(), + detector1.getFilterQuery(), + detector1.getInterval(), + detector1.getWindowDelay(), + detector1.getShingleSize(), + detector1.getUiMetadata(), + detector1.getSchemaVersion(), + detector1.getLastUpdateTime(), + null, + detector1.getUser(), + null, + TestHelpers.randomImputationOption() + ); + + TestHelpers + .assertFailWith( + ResponseException.class, + "Cannot create anomaly detector with name", + () -> TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI, + ImmutableMap.of(), + TestHelpers.toHttpEntity(newDetector1WithDetector2Name), + null + ) + ); + } + + public void testUpdateAnomalyDetectorNameToNew() throws Exception { + AnomalyDetector detector = createAnomalyDetector(createIndexAndGetAnomalyDetector(INDEX_NAME), true, client()); + AnomalyDetector detectorWithNewName = new AnomalyDetector( + detector.getId(), + detector.getVersion(), + randomAlphaOfLength(5), + detector.getDescription(), + detector.getTimeField(), + detector.getIndices(), + detector.getFeatureAttributes(), + detector.getFilterQuery(), + detector.getInterval(), + detector.getWindowDelay(), + detector.getShingleSize(), + detector.getUiMetadata(), + detector.getSchemaVersion(), + Instant.now(), + null, + detector.getUser(), + null, + TestHelpers.randomImputationOption() + ); + + TestHelpers + .makeRequest( + client(), + "PUT", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "?refresh=true", + ImmutableMap.of(), + TestHelpers.toHttpEntity(detectorWithNewName), + null + ); + + AnomalyDetector resultDetector = getAnomalyDetector(detectorWithNewName.getId(), client()); + assertEquals("Detector name updating failed", detectorWithNewName.getName(), resultDetector.getName()); + assertEquals("Updated anomaly detector id doesn't match", detectorWithNewName.getId(), resultDetector.getId()); + assertNotEquals( + "Anomaly detector last update time not changed", + detectorWithNewName.getLastUpdateTime(), + resultDetector.getLastUpdateTime() + ); + } + + public void testUpdateAnomalyDetectorWithNotExistingIndex() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, true, client()); + + String newDescription = randomAlphaOfLength(5); + + AnomalyDetector newDetector = new AnomalyDetector( + detector.getId(), + detector.getVersion(), + detector.getName(), + newDescription, + detector.getTimeField(), + detector.getIndices(), + detector.getFeatureAttributes(), + detector.getFilterQuery(), + detector.getInterval(), + detector.getWindowDelay(), + detector.getShingleSize(), + detector.getUiMetadata(), + detector.getSchemaVersion(), + detector.getLastUpdateTime(), + null, + detector.getUser(), + null, + TestHelpers.randomImputationOption() + ); + + 
deleteIndexWithAdminClient(CommonName.CONFIG_INDEX); + + TestHelpers + .assertFailWith( + ResponseException.class, + null, + () -> TestHelpers + .makeRequest( + client(), + "PUT", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId(), + ImmutableMap.of(), + TestHelpers.toHttpEntity(newDetector), + null + ) + ); + } + + public void testSearchAnomalyDetector() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, true, client()); + SearchSourceBuilder search = (new SearchSourceBuilder()).query(QueryBuilders.termQuery("_id", detector.getId())); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, false); + + Exception ex = expectThrows( + ResponseException.class, + () -> TestHelpers + .makeRequest( + client(), + "GET", + TestHelpers.AD_BASE_DETECTORS_URI + "/_search", + ImmutableMap.of(), + new StringEntity(search.toString(), ContentType.APPLICATION_JSON), + null + ) + ); + assertThat(ex.getMessage(), containsString(ADCommonMessages.DISABLED_ERR_MSG)); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, true); + + Response searchResponse = TestHelpers + .makeRequest( + client(), + "GET", + TestHelpers.AD_BASE_DETECTORS_URI + "/_search", + ImmutableMap.of(), + new StringEntity(search.toString(), ContentType.APPLICATION_JSON), + null + ); + assertEquals("Search anomaly detector failed", RestStatus.OK, TestHelpers.restStatus(searchResponse)); + } + + public void testStatsAnomalyDetector() throws Exception { + updateClusterSettings(ADEnabledSetting.AD_ENABLED, false); + Exception ex = expectThrows( + ResponseException.class, + () -> TestHelpers.makeRequest(client(), "GET", TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE + "/stats", ImmutableMap.of(), "", null) + ); + assertThat(ex.getMessage(), containsString(ADCommonMessages.DISABLED_ERR_MSG)); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, true); + + Response statsResponse = TestHelpers + .makeRequest(client(), "GET", TimeSeriesAnalyticsPlugin.LEGACY_AD_BASE + "/stats", ImmutableMap.of(), "", null); + + assertEquals("Get stats failed", RestStatus.OK, TestHelpers.restStatus(statsResponse)); + } + + public void testPreviewAnomalyDetector() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, false, client()); + AnomalyDetectorExecutionInput input = new AnomalyDetectorExecutionInput( + detector.getId(), + Instant.now().minusSeconds(60 * 10), + Instant.now(), + null + ); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, false); + + Exception ex = expectThrows( + ResponseException.class, + () -> TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TestHelpers.AD_BASE_PREVIEW_URI, input.getDetectorId()), + ImmutableMap.of(), + TestHelpers.toHttpEntity(input), + null + ) + ); + assertThat(ex.getMessage(), containsString(ADCommonMessages.DISABLED_ERR_MSG)); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, true); + + Response response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TestHelpers.AD_BASE_PREVIEW_URI, input.getDetectorId()), + ImmutableMap.of(), + TestHelpers.toHttpEntity(input), + null + ); + assertEquals("Execute anomaly detector failed", RestStatus.OK, TestHelpers.restStatus(response)); + } + + public void testPreviewAnomalyDetectorWhichNotExist() throws Exception { + createRandomAnomalyDetector(true, false, client()); + AnomalyDetectorExecutionInput input = new AnomalyDetectorExecutionInput( + randomAlphaOfLength(5), + Instant.now().minusSeconds(60 * 10), + Instant.now(), + null + ); + TestHelpers + 
.assertFailWith( + ResponseException.class, + () -> TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TestHelpers.AD_BASE_PREVIEW_URI, input.getDetectorId()), + ImmutableMap.of(), + TestHelpers.toHttpEntity(input), + null + ) + ); + } + + public void testExecuteAnomalyDetectorWithNullDetectorId() throws Exception { + AnomalyDetectorExecutionInput input = new AnomalyDetectorExecutionInput( + null, + Instant.now().minusSeconds(60 * 10), + Instant.now(), + null + ); + TestHelpers + .assertFailWith( + ResponseException.class, + () -> TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TestHelpers.AD_BASE_PREVIEW_URI, input.getDetectorId()), + ImmutableMap.of(), + TestHelpers.toHttpEntity(input), + null + ) + ); + } + + public void testPreviewAnomalyDetectorWithDetector() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, true, client()); + AnomalyDetectorExecutionInput input = new AnomalyDetectorExecutionInput( + detector.getId(), + Instant.now().minusSeconds(60 * 10), + Instant.now(), + detector + ); + Response response = TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TestHelpers.AD_BASE_PREVIEW_URI, input.getDetectorId()), + ImmutableMap.of(), + TestHelpers.toHttpEntity(input), + null, + false + ); + assertEquals("Execute anomaly detector failed", RestStatus.OK, TestHelpers.restStatus(response)); + } + + public void testPreviewAnomalyDetectorWithDetectorAndNoFeatures() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, true, client()); + AnomalyDetectorExecutionInput input = new AnomalyDetectorExecutionInput( + detector.getId(), + Instant.now().minusSeconds(60 * 10), + Instant.now(), + TestHelpers.randomAnomalyDetectorWithEmptyFeature() + ); + TestHelpers + .assertFailWith( + ResponseException.class, + "Can't preview detector without feature", + () -> TestHelpers + .makeRequest( + client(), + "POST", + String.format(Locale.ROOT, TestHelpers.AD_BASE_PREVIEW_URI, input.getDetectorId()), + ImmutableMap.of(), + TestHelpers.toHttpEntity(input), + null + ) + ); + } + + public void testSearchAnomalyResult() throws Exception { + AnomalyResult anomalyResult = TestHelpers.randomAnomalyDetectResult(); + Response response = TestHelpers + .makeRequest( + adminClient(), + "POST", + "/.opendistro-anomaly-results/_doc/" + UUIDs.base64UUID(), + ImmutableMap.of(), + TestHelpers.toHttpEntity(anomalyResult), + null, + false + ); + assertEquals("Post anomaly result failed", RestStatus.CREATED, TestHelpers.restStatus(response)); + + SearchSourceBuilder search = (new SearchSourceBuilder()).query(QueryBuilders.termQuery("detector_id", anomalyResult.getConfigId())); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, false); + + Exception ex = expectThrows( + ResponseException.class, + () -> TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_RESULT_URI + "/_search", + ImmutableMap.of(), + new StringEntity(search.toString(), ContentType.APPLICATION_JSON), + null + ) + ); + assertThat(ex.getMessage(), containsString(ADCommonMessages.DISABLED_ERR_MSG)); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, true); + + Response searchResponse = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_RESULT_URI + "/_search", + ImmutableMap.of(), + new StringEntity(search.toString(), ContentType.APPLICATION_JSON), + null + ); + assertEquals("Search anomaly result failed", RestStatus.OK, TestHelpers.restStatus(searchResponse)); + + 
SearchSourceBuilder searchAll = SearchSourceBuilder.fromXContent(TestHelpers.parser("{\"query\":{\"match_all\":{}}}")); + Response searchAllResponse = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_RESULT_URI + "/_search", + ImmutableMap.of(), + new StringEntity(searchAll.toString(), ContentType.APPLICATION_JSON), + null + ); + assertEquals("Search anomaly result failed", RestStatus.OK, TestHelpers.restStatus(searchAllResponse)); + } + + public void testDeleteAnomalyDetector() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, false, client()); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, false); + + Exception ex = expectThrows( + ResponseException.class, + () -> TestHelpers + .makeRequest(client(), "DELETE", TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId(), ImmutableMap.of(), "", null) + ); + assertThat(ex.getMessage(), containsString(ADCommonMessages.DISABLED_ERR_MSG)); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, true); + Response response = TestHelpers + .makeRequest(client(), "DELETE", TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId(), ImmutableMap.of(), "", null); + assertEquals("Delete anomaly detector failed", RestStatus.OK, TestHelpers.restStatus(response)); + } + + public void testDeleteAnomalyDetectorWhichNotExist() throws Exception { + TestHelpers + .assertFailWith( + ResponseException.class, + () -> TestHelpers + .makeRequest( + client(), + "DELETE", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + randomAlphaOfLength(5), + ImmutableMap.of(), + "", + null + ) + ); + } + + public void testDeleteAnomalyDetectorWithNoAdJob() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, false, client()); + Response response = TestHelpers + .makeRequest(client(), "DELETE", TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId(), ImmutableMap.of(), "", null); + assertEquals("Delete anomaly detector failed", RestStatus.OK, TestHelpers.restStatus(response)); + } + + public void testDeleteAnomalyDetectorWithRunningAdJob() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, false, client()); + Response startAdJobResponse = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_start", + ImmutableMap.of(), + "", + null + ); + + assertEquals("Fail to start AD job", RestStatus.OK, TestHelpers.restStatus(startAdJobResponse)); + + TestHelpers + .assertFailWith( + ResponseException.class, + "Detector job is running", + () -> TestHelpers + .makeRequest( + client(), + "DELETE", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId(), + ImmutableMap.of(), + "", + null + ) + ); + } + + public void testUpdateAnomalyDetectorWithRunningAdJob() throws Exception { + AnomalyDetector detector = createAnomalyDetector(createIndexAndGetAnomalyDetector(INDEX_NAME), true, client()); + Response startAdJobResponse = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_start", + ImmutableMap.of(), + "", + null + ); + + assertEquals("Fail to start AD job", RestStatus.OK, TestHelpers.restStatus(startAdJobResponse)); + + String newDescription = randomAlphaOfLength(5); + + AnomalyDetector newDetector = new AnomalyDetector( + detector.getId(), + detector.getVersion(), + detector.getName(), + newDescription, + detector.getTimeField(), + detector.getIndices(), + detector.getFeatureAttributes(), + detector.getFilterQuery(), + detector.getInterval(), + 
detector.getWindowDelay(), + detector.getShingleSize(), + detector.getUiMetadata(), + detector.getSchemaVersion(), + detector.getLastUpdateTime(), + null, + detector.getUser(), + null, + TestHelpers.randomImputationOption() + ); + + TestHelpers + .assertFailWith( + ResponseException.class, + "Detector job is running", + () -> TestHelpers + .makeRequest( + client(), + "PUT", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId(), + ImmutableMap.of(), + TestHelpers.toHttpEntity(newDetector), + null + ) + ); + } + + public void testGetDetectorWithAdJob() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, false, client()); + Response startAdJobResponse = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_start", + ImmutableMap.of(), + "", + null + ); + + assertEquals("Fail to start AD job", RestStatus.OK, TestHelpers.restStatus(startAdJobResponse)); + + ToXContentObject[] results = getAnomalyDetector(detector.getId(), true, client()); + assertEquals("Incorrect Location header", detector, results[0]); + assertEquals("Incorrect detector job name", detector.getId(), ((AnomalyDetectorJob) results[1]).getName()); + assertTrue(((AnomalyDetectorJob) results[1]).isEnabled()); + + results = getAnomalyDetector(detector.getId(), false, client()); + assertEquals("Incorrect Location header", detector, results[0]); + assertEquals("Should not return detector job", null, results[1]); + } + + public void testStartAdJobWithExistingDetector() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, false, client()); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, false); + + Exception ex = expectThrows( + ResponseException.class, + () -> TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_start", + ImmutableMap.of(), + "", + null + ) + ); + assertThat(ex.getMessage(), containsString(ADCommonMessages.DISABLED_ERR_MSG)); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, true); + Response startAdJobResponse = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_start", + ImmutableMap.of(), + "", + null + ); + + assertEquals("Fail to start AD job", RestStatus.OK, TestHelpers.restStatus(startAdJobResponse)); + + startAdJobResponse = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_start", + ImmutableMap.of(), + "", + null + ); + + assertEquals("Fail to start AD job", RestStatus.OK, TestHelpers.restStatus(startAdJobResponse)); + } + + public void testStartAdJobWithNonexistingDetectorIndex() throws Exception { + TestHelpers + .assertFailWith( + ResponseException.class, + "no such index [.opendistro-anomaly-detectors]", + () -> TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + randomAlphaOfLength(10) + "/_start", + ImmutableMap.of(), + "", + null + ) + ); + } + + public void testStartAdJobWithNonexistingDetector() throws Exception { + createRandomAnomalyDetector(true, false, client()); + TestHelpers + .assertFailWith( + ResponseException.class, + FAIL_TO_FIND_CONFIG_MSG, + () -> TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + randomAlphaOfLength(10) + "/_start", + ImmutableMap.of(), + "", + null + ) + ); + } + + public void testStopAdJob() throws Exception { + 
updateClusterSettings(ADEnabledSetting.AD_ENABLED, true); + AnomalyDetector detector = createRandomAnomalyDetector(true, false, client()); + Response startAdJobResponse = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_start", + ImmutableMap.of(), + "", + null + ); + assertEquals("Fail to start AD job", RestStatus.OK, TestHelpers.restStatus(startAdJobResponse)); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, false); + + Exception ex = expectThrows( + ResponseException.class, + () -> TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_stop", + ImmutableMap.of(), + "", + null + ) + ); + assertThat(ex.getMessage(), containsString(ADCommonMessages.DISABLED_ERR_MSG)); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, true); + + Response stopAdJobResponse = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_stop", + ImmutableMap.of(), + "", + null + ); + assertEquals("Fail to stop AD job", RestStatus.OK, TestHelpers.restStatus(stopAdJobResponse)); + + stopAdJobResponse = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_stop", + ImmutableMap.of(), + "", + null + ); + assertEquals("Fail to stop AD job", RestStatus.OK, TestHelpers.restStatus(stopAdJobResponse)); + } + + public void testStopNonExistingAdJobIndex() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, true, client()); + TestHelpers + .assertFailWith( + ResponseException.class, + "no such index [.opendistro-anomaly-detector-jobs]", + () -> TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_stop", + ImmutableMap.of(), + "", + null + ) + ); + } + + public void testStopNonExistingAdJob() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, false, client()); + Response startAdJobResponse = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_start", + ImmutableMap.of(), + "", + null + ); + assertEquals("Fail to start AD job", RestStatus.OK, TestHelpers.restStatus(startAdJobResponse)); + + TestHelpers + .assertFailWith( + ResponseException.class, + FAIL_TO_FIND_CONFIG_MSG, + () -> TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + randomAlphaOfLength(10) + "/_stop", + ImmutableMap.of(), + "", + null + ) + ); + } + + public void testStartDisabledAdjob() throws IOException { + AnomalyDetector detector = createRandomAnomalyDetector(true, false, client()); + Response startAdJobResponse = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_start", + ImmutableMap.of(), + "", + null + ); + assertEquals("Fail to start AD job", RestStatus.OK, TestHelpers.restStatus(startAdJobResponse)); + + Response stopAdJobResponse = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_stop", + ImmutableMap.of(), + "", + null + ); + assertEquals("Fail to stop AD job", RestStatus.OK, TestHelpers.restStatus(stopAdJobResponse)); + + startAdJobResponse = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_start", + ImmutableMap.of(), + "", + null + ); + + assertEquals("Fail to start 
AD job", RestStatus.OK, TestHelpers.restStatus(startAdJobResponse)); + } + + public void testStartAdjobWithNullFeatures() throws Exception { + AnomalyDetector detectorWithoutFeature = TestHelpers.randomAnomalyDetector(null, null, Instant.now()); + String indexName = detectorWithoutFeature.getIndices().get(0); + TestHelpers.createIndex(client(), indexName, TestHelpers.toHttpEntity("{\"name\": \"test\"}")); + AnomalyDetector detector = createAnomalyDetector(detectorWithoutFeature, true, client()); + TestHelpers + .assertFailWith( + ResponseException.class, + "Can't start detector job as no features configured", + () -> TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_start", + ImmutableMap.of(), + "", + null + ) + ); + } + + public void testStartAdjobWithEmptyFeatures() throws Exception { + AnomalyDetector detectorWithoutFeature = TestHelpers.randomAnomalyDetector(ImmutableList.of(), null, Instant.now()); + String indexName = detectorWithoutFeature.getIndices().get(0); + TestHelpers.createIndex(client(), indexName, TestHelpers.toHttpEntity("{\"name\": \"test\"}")); + AnomalyDetector detector = createAnomalyDetector(detectorWithoutFeature, true, client()); + TestHelpers + .assertFailWith( + ResponseException.class, + "Can't start detector job as no features configured", + () -> TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detector.getId() + "/_start", + ImmutableMap.of(), + "", + null + ) + ); + } + + public void testDefaultProfileAnomalyDetector() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, true, client()); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, false); + + Exception ex = expectThrows(ResponseException.class, () -> getDetectorProfile(detector.getId())); + assertThat(ex.getMessage(), containsString(ADCommonMessages.DISABLED_ERR_MSG)); + + updateClusterSettings(ADEnabledSetting.AD_ENABLED, true); + + Response profileResponse = getDetectorProfile(detector.getId()); + assertEquals("Incorrect profile status", RestStatus.OK, TestHelpers.restStatus(profileResponse)); + } + + public void testAllProfileAnomalyDetector() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, true, client()); + Response profileResponse = getDetectorProfile(detector.getId(), true); + assertEquals("Incorrect profile status", RestStatus.OK, TestHelpers.restStatus(profileResponse)); + } + + public void testCustomizedProfileAnomalyDetector() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, true, client()); + Response profileResponse = getDetectorProfile(detector.getId(), true, "/models/", client()); + assertEquals("Incorrect profile status", RestStatus.OK, TestHelpers.restStatus(profileResponse)); + } + + public void testSearchAnomalyDetectorCountNoIndex() throws Exception { + Response countResponse = getSearchDetectorCount(); + Map responseMap = entityAsMap(countResponse); + Integer count = (Integer) responseMap.get("count"); + assertEquals((long) count, 0); + } + + public void testSearchAnomalyDetectorCount() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, true, client()); + Response countResponse = getSearchDetectorCount(); + Map responseMap = entityAsMap(countResponse); + Integer count = (Integer) responseMap.get("count"); + assertEquals((long) count, 1); + } + + public void testSearchAnomalyDetectorMatchNoIndex() throws Exception { + Response matchResponse = 
getSearchDetectorMatch("name"); + Map responseMap = entityAsMap(matchResponse); + boolean nameExists = (boolean) responseMap.get("match"); + assertEquals(nameExists, false); + } + + public void testSearchAnomalyDetectorNoMatch() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, true, client()); + Response matchResponse = getSearchDetectorMatch(detector.getName()); + Map responseMap = entityAsMap(matchResponse); + boolean nameExists = (boolean) responseMap.get("match"); + assertEquals(nameExists, true); + } + + public void testSearchAnomalyDetectorMatch() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, true, client()); + Response matchResponse = getSearchDetectorMatch(detector.getName() + "newDetector"); + Map responseMap = entityAsMap(matchResponse); + boolean nameExists = (boolean) responseMap.get("match"); + assertEquals(nameExists, false); + } + + public void testRunDetectorWithNoEnabledFeature() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, true, client(), false); + Assert.assertNotNull(detector.getId()); + Instant now = Instant.now(); + ResponseException e = expectThrows( + ResponseException.class, + () -> startAnomalyDetector(detector.getId(), new DateRange(now.minus(10, ChronoUnit.DAYS), now), client()) + ); + assertTrue(e.getMessage().contains("Can't start detector job as no enabled features configured")); + } + + public void testDeleteAnomalyDetectorWhileRunning() throws Exception { + AnomalyDetector detector = createRandomAnomalyDetector(true, true, client()); + Assert.assertNotNull(detector.getId()); + Instant now = Instant.now(); + Response response = startAnomalyDetector(detector.getId(), new DateRange(now.minus(10, ChronoUnit.DAYS), now), client()); + Assert.assertThat(response.getStatusLine().toString(), CoreMatchers.containsString("200 OK")); + + // Deleting detector should fail while its running + Exception exception = expectThrows(IOException.class, () -> { deleteAnomalyDetector(detector.getId(), client()); }); + Assert.assertTrue(exception.getMessage().contains("Detector is running")); + } + + public void testBackwardCompatibilityWithOpenDistro() throws IOException { + // Create a detector + AnomalyDetector detector = createIndexAndGetAnomalyDetector(INDEX_NAME); + // Verify the detector is created using legacy _opendistro API + Response response = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI, + ImmutableMap.of(), + TestHelpers.toHttpEntity(detector), + null + ); + assertEquals("Create anomaly detector failed", RestStatus.CREATED, TestHelpers.restStatus(response)); + Map responseMap = entityAsMap(response); + String id = (String) responseMap.get("_id"); + int version = (int) responseMap.get("_version"); + assertNotEquals("response is missing Id", AnomalyDetector.NO_ID, id); + assertTrue("incorrect version", version > 0); + + // Get the detector using new _plugins API + AnomalyDetector createdDetector = getAnomalyDetector(id, client()); + assertEquals("Get anomaly detector failed", createdDetector.getId(), id); + + // Delete the detector using legacy _opendistro API + response = TestHelpers + .makeRequest( + client(), + "DELETE", + TestHelpers.LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI + "/" + createdDetector.getId(), + ImmutableMap.of(), + "", + null + ); + assertEquals("Delete anomaly detector failed", RestStatus.OK, TestHelpers.restStatus(response)); + + } + + public void 
testValidateAnomalyDetectorWithDuplicateName() throws Exception { + AnomalyDetector detector = createAnomalyDetector(createIndexAndGetAnomalyDetector(INDEX_NAME), true, client()); + Response resp = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/_validate", + ImmutableMap.of(), + TestHelpers + .toHttpEntity( + "{\"name\":\"" + + detector.getName() + + "\",\"description\":\"Test detector\",\"time_field\":\"timestamp\"," + + "\"indices\":[\"" + + INDEX_NAME + + "\"],\"feature_attributes\":[{\"feature_name\":\"cpu-sum\",\"" + + "feature_enabled\":true,\"aggregation_query\":{\"total_cpu\":{\"sum\":{\"field\":\"cpu\"}}}}," + + "{\"feature_name\":\"error-sum\",\"feature_enabled\":true,\"aggregation_query\":" + + "{\"total_error\":" + + "{\"sum\":{\"field\":\"error\"}}}}],\"filter_query\":{\"bool\":{\"filter\":[{\"exists\":" + + "{\"field\":" + + "\"cpu\",\"boost\":1}}],\"adjust_pure_negative\":true,\"boost\":1}},\"detection_interval\":" + + "{\"period\":{\"interval\":1,\"unit\":\"Minutes\"}}," + + "\"window_delay\":{\"period\":{\"interval\":2,\"unit\":\"Minutes\"}}}" + ), + null + ); + Map responseMap = entityAsMap(resp); + @SuppressWarnings("unchecked") + Map> messageMap = (Map>) XContentMapValues + .extractValue("detector", responseMap); + assertEquals("Validation returned duplicate detector name message", RestStatus.OK, TestHelpers.restStatus(resp)); + String errorMsg = String.format(Locale.ROOT, DUPLICATE_DETECTOR_MSG, detector.getName(), "[" + detector.getId() + "]"); + assertEquals("duplicate error message", errorMsg, messageMap.get("name").get("message")); + } + + public void testValidateAnomalyDetectorWithNoTimeField() throws Exception { + TestHelpers.createIndex(client(), "test-index", TestHelpers.toHttpEntity("{\"timestamp\": " + Instant.now().toEpochMilli() + "}")); + Response resp = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/_validate", + ImmutableMap.of(), + TestHelpers + .toHttpEntity( + "{\"name\":\"test\",\"description\":\"\"" + + ",\"indices\":[\"test-index\"],\"feature_attributes\":[{\"feature_name\":\"test\"," + + "\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}]," + + "\"filter_query\":{},\"detection_interval\":{\"period\":{\"interval\":1,\"unit\":\"Minutes\"}}," + + "\"window_delay\":{\"period\":{\"interval\":1,\"unit\":\"Minutes\"}}}" + ), + null + ); + Map responseMap = entityAsMap(resp); + @SuppressWarnings("unchecked") + Map> messageMap = (Map>) XContentMapValues + .extractValue("detector", responseMap); + assertEquals("Validation response returned", RestStatus.OK, TestHelpers.restStatus(resp)); + assertEquals("time field missing", CommonMessages.NULL_TIME_FIELD, messageMap.get("time_field").get("message")); + } + + public void testValidateAnomalyDetectorWithIncorrectShingleSize() throws Exception { + TestHelpers.createIndex(client(), "test-index", TestHelpers.toHttpEntity("{\"timestamp\": " + Instant.now().toEpochMilli() + "}")); + Response resp = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/_validate", + ImmutableMap.of(), + TestHelpers + .toHttpEntity( + "{\"name\":\"" + + "test-detector" + + "\",\"description\":\"Test detector\",\"time_field\":\"timestamp\"," + + "\"indices\":[\"test-index\"],\"feature_attributes\":[{\"feature_name\":\"cpu-sum\",\"" + + "feature_enabled\":true,\"aggregation_query\":{\"total_cpu\":{\"sum\":{\"field\":\"cpu\"}}}}," + + 
"{\"feature_name\":\"error-sum\",\"feature_enabled\":true,\"aggregation_query\":" + + "{\"total_error\":" + + "{\"sum\":{\"field\":\"error\"}}}}],\"filter_query\":{\"bool\":{\"filter\":[{\"exists\":" + + "{\"field\":" + + "\"cpu\",\"boost\":1}}],\"adjust_pure_negative\":true,\"boost\":1}},\"detection_interval\":" + + "{\"period\":{\"interval\":1,\"unit\":\"Minutes\"}}," + + "\"window_delay\":{\"period\":{\"interval\":2,\"unit\":\"Minutes\"}}," + + "\"shingle_size\": 2000}" + ), + null + ); + Map responseMap = entityAsMap(resp); + @SuppressWarnings("unchecked") + Map> messageMap = (Map>) XContentMapValues + .extractValue("detector", responseMap); + String errorMessage = "Shingle size must be a positive integer no larger than " + + TimeSeriesSettings.MAX_SHINGLE_SIZE + + ". Got 2000"; + assertEquals("shingle size error message", errorMessage, messageMap.get("shingle_size").get("message")); + } + + public void testValidateAnomalyDetectorWithNoIssue() throws Exception { + AnomalyDetector detector = createIndexAndGetAnomalyDetector(INDEX_NAME); + Response resp = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/_validate/detector", + ImmutableMap.of(), + TestHelpers.toHttpEntity(detector), + null + ); + Map responseMap = entityAsMap(resp); + assertEquals("no issue, empty response body", new HashMap(), responseMap); + } + + public void testValidateAnomalyDetectorOnWrongValidationType() throws Exception { + AnomalyDetector detector = createIndexAndGetAnomalyDetector(INDEX_NAME); + TestHelpers + .assertFailWith( + ResponseException.class, + ADCommonMessages.NOT_EXISTENT_VALIDATION_TYPE, + () -> TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/_validate/models", + ImmutableMap.of(), + TestHelpers.toHttpEntity(detector), + null + ) + ); + } + + public void testValidateAnomalyDetectorWithEmptyIndices() throws Exception { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(TIME_FIELD, INDEX_NAME); + TestHelpers + .makeRequest( + client(), + "PUT", + "/" + detector.getIndices().get(0), + ImmutableMap.of(), + TestHelpers + .toHttpEntity( + "{\"settings\":{\"number_of_shards\":1}," + + " \"mappings\":{\"properties\":" + + "{\"timestamp\":{\"type\":\"date\"}}}}" + + "{\"field1\":{\"type\":\"text\"}}}}" + ), + null + ); + Response resp = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/_validate", + ImmutableMap.of(), + TestHelpers.toHttpEntity(detector), + null + ); + Map responseMap = entityAsMap(resp); + @SuppressWarnings("unchecked") + Map> messageMap = (Map>) XContentMapValues + .extractValue("detector", responseMap); + assertEquals("Validation returned message regarding empty indices", RestStatus.OK, TestHelpers.restStatus(resp)); + String errorMessage = NO_DOCS_IN_USER_INDEX_MSG + "[" + detector.getIndices().get(0) + "]"; + assertEquals("duplicate error message", errorMessage, messageMap.get("indices").get("message")); + } + + public void testValidateAnomalyDetectorWithInvalidName() throws Exception { + TestHelpers.createIndex(client(), "test-index", TestHelpers.toHttpEntity("{\"timestamp\": " + Instant.now().toEpochMilli() + "}")); + Response resp = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/_validate/detector", + ImmutableMap.of(), + TestHelpers + .toHttpEntity( + "{\"name\":\"#@$3\",\"description\":\"\",\"time_field\":\"timestamp\"" + + ",\"indices\":[\"test-index\"],\"feature_attributes\":[{\"feature_name\":\"test\"," 
+ + "\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}]," + + "\"filter_query\":{},\"detection_interval\":{\"period\":{\"interval\":1,\"unit\":\"Minutes\"}}," + + "\"window_delay\":{\"period\":{\"interval\":1,\"unit\":\"Minutes\"}}}" + ), + null + ); + Map responseMap = entityAsMap(resp); + @SuppressWarnings("unchecked") + Map> messageMap = (Map>) XContentMapValues + .extractValue("detector", responseMap); + assertEquals("invalid detector Name", CommonMessages.INVALID_NAME, messageMap.get("name").get("message")); + } + + public void testValidateAnomalyDetectorWithFeatureQueryReturningNoData() throws Exception { + Feature emptyFeature = TestHelpers.randomFeature("f-empty", "cpu", "avg", true); + AnomalyDetector detector = createIndexAndGetAnomalyDetector(INDEX_NAME, ImmutableList.of(emptyFeature)); + Response resp = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/_validate/detector", + ImmutableMap.of(), + TestHelpers.toHttpEntity(detector), + null + ); + Map responseMap = entityAsMap(resp); + @SuppressWarnings("unchecked") + Map> messageMap = (Map>) XContentMapValues + .extractValue("detector", responseMap); + assertEquals( + "empty data", + CommonMessages.FEATURE_WITH_EMPTY_DATA_MSG + "f-empty", + messageMap.get("feature_attributes").get("message") + ); + } + + public void testValidateAnomalyDetectorWithFeatureQueryRuntimeException() throws Exception { + Feature nonNumericFeature = TestHelpers.randomFeature("non-numeric-feature", "_index", "avg", true); + AnomalyDetector detector = createIndexAndGetAnomalyDetector(INDEX_NAME, ImmutableList.of(nonNumericFeature)); + Response resp = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/_validate/detector", + ImmutableMap.of(), + TestHelpers.toHttpEntity(detector), + null + ); + Map responseMap = entityAsMap(resp); + @SuppressWarnings("unchecked") + Map> messageMap = (Map>) XContentMapValues + .extractValue("detector", responseMap); + assertEquals( + "runtime exception", + CommonMessages.FEATURE_WITH_INVALID_QUERY_MSG + "non-numeric-feature", + messageMap.get("feature_attributes").get("message") + ); + } + + public void testValidateAnomalyDetectorWithWrongCategoryField() throws Exception { + AnomalyDetector detector = TestHelpers + .randomAnomalyDetectorUsingCategoryFields( + randomAlphaOfLength(5), + TIME_FIELD, + ImmutableList.of("index-test"), + Arrays.asList("host.keyword") + ); + TestHelpers.createIndexWithTimeField(client(), "index-test", TIME_FIELD); + Response resp = TestHelpers + .makeRequest( + client(), + "POST", + TestHelpers.AD_BASE_DETECTORS_URI + "/_validate/detector", + ImmutableMap.of(), + TestHelpers.toHttpEntity(detector), + null + ); + Map responseMap = entityAsMap(resp); + @SuppressWarnings("unchecked") + Map> messageMap = (Map>) XContentMapValues + .extractValue("detector", responseMap); + assertEquals( + "non-existing category", + String.format(Locale.ROOT, AbstractAnomalyDetectorActionHandler.CATEGORY_NOT_FOUND_ERR_MSG, "host.keyword"), + messageMap.get("category_field").get("message") + ); + + } + + public void testSearchTopAnomalyResultsWithInvalidInputs() throws IOException { + String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); + Map categoryFieldsAndTypes = new HashMap() { + { + put("keyword-field", "keyword"); + put("ip-field", "ip"); + } + }; + String testIndexData = "{\"keyword-field\": \"field-1\", \"ip-field\": \"1.2.3.4\", \"timestamp\": 1}"; + 
TestHelpers.createIndexWithHCADFields(client(), indexName, categoryFieldsAndTypes); + TestHelpers.ingestDataToIndex(client(), indexName, TestHelpers.toHttpEntity(testIndexData)); + AnomalyDetector detector = createAnomalyDetector( + TestHelpers + .randomAnomalyDetectorUsingCategoryFields( + randomAlphaOfLength(10), + TIME_FIELD, + ImmutableList.of(indexName), + categoryFieldsAndTypes.keySet().stream().collect(Collectors.toList()) + ), + true, + client() + ); + + // Missing start time + Exception missingStartTimeException = expectThrows( + IOException.class, + () -> { searchTopAnomalyResults(detector.getId(), false, "{\"end_time_ms\":2}", client()); } + ); + assertTrue(missingStartTimeException.getMessage().contains("Must set both start time and end time with epoch of milliseconds")); + + // Missing end time + Exception missingEndTimeException = expectThrows( + IOException.class, + () -> { searchTopAnomalyResults(detector.getId(), false, "{\"start_time_ms\":1}", client()); } + ); + assertTrue(missingEndTimeException.getMessage().contains("Must set both start time and end time with epoch of milliseconds")); + + // Start time > end time + Exception invalidTimeException = expectThrows( + IOException.class, + () -> { searchTopAnomalyResults(detector.getId(), false, "{\"start_time_ms\":2, \"end_time_ms\":1}", client()); } + ); + assertTrue(invalidTimeException.getMessage().contains("Start time should be before end time")); + + // Invalid detector ID + Exception invalidDetectorIdException = expectThrows( + IOException.class, + () -> { searchTopAnomalyResults(detector.getId() + "-invalid", false, "{\"start_time_ms\":1, \"end_time_ms\":2}", client()); } + ); + assertTrue(invalidDetectorIdException.getMessage().contains("Can't find config with id")); + + // Invalid order field + Exception invalidOrderException = expectThrows( + IOException.class, + () -> { + searchTopAnomalyResults( + detector.getId(), + false, + "{\"start_time_ms\":1, \"end_time_ms\":2, \"order\":\"invalid-order\"}", + client() + ); + } + ); + assertTrue(invalidOrderException.getMessage().contains("Ordering by invalid-order is not a valid option")); + + // Negative size field + Exception negativeSizeException = expectThrows( + IOException.class, + () -> { searchTopAnomalyResults(detector.getId(), false, "{\"start_time_ms\":1, \"end_time_ms\":2, \"size\":-1}", client()); } + ); + assertTrue(negativeSizeException.getMessage().contains("Size must be a positive integer")); + + // Zero size field + Exception zeroSizeException = expectThrows( + IOException.class, + () -> { searchTopAnomalyResults(detector.getId(), false, "{\"start_time_ms\":1, \"end_time_ms\":2, \"size\":0}", client()); } + ); + assertTrue(zeroSizeException.getMessage().contains("Size must be a positive integer")); + + // Too large size field + Exception tooLargeSizeException = expectThrows( + IOException.class, + () -> { + searchTopAnomalyResults(detector.getId(), false, "{\"start_time_ms\":1, \"end_time_ms\":2, \"size\":9999999}", client()); + } + ); + assertTrue(tooLargeSizeException.getMessage().contains("Size cannot exceed")); + + // No existing task ID for detector + Exception noTaskIdException = expectThrows( + IOException.class, + () -> { searchTopAnomalyResults(detector.getId(), true, "{\"start_time_ms\":1, \"end_time_ms\":2}", client()); } + ); + assertTrue(noTaskIdException.getMessage().contains("No historical tasks found for detector ID " + detector.getId())); + + // Invalid category fields + Exception invalidCategoryFieldsException = 
expectThrows(IOException.class, () -> { + searchTopAnomalyResults( + detector.getId(), + false, + "{\"start_time_ms\":1, \"end_time_ms\":2, \"category_field\":[\"invalid-field\"]}", + client() + ); + }); + assertTrue( + invalidCategoryFieldsException + .getMessage() + .contains("Category field invalid-field doesn't exist for detector ID " + detector.getId()) + ); + + // Using detector with no category fields + AnomalyDetector detectorWithNoCategoryFields = createAnomalyDetector( + TestHelpers + .randomAnomalyDetectorUsingCategoryFields( + randomAlphaOfLength(10), + TIME_FIELD, + ImmutableList.of(indexName), + ImmutableList.of() + ), + true, + client() + ); + Exception noCategoryFieldsException = expectThrows( + IOException.class, + () -> { + searchTopAnomalyResults(detectorWithNoCategoryFields.getId(), false, "{\"start_time_ms\":1, \"end_time_ms\":2}", client()); + } + ); + assertTrue( + noCategoryFieldsException + .getMessage() + .contains("No category fields found for detector ID " + detectorWithNoCategoryFields.getId()) + ); + } + + public void testSearchTopAnomalyResultsOnNonExistentResultIndex() throws IOException { + String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); + Map<String, String> categoryFieldsAndTypes = new HashMap<String, String>() { + { + put("keyword-field", "keyword"); + put("ip-field", "ip"); + } + }; + String testIndexData = "{\"keyword-field\": \"test-value\"}"; + TestHelpers.createIndexWithHCADFields(client(), indexName, categoryFieldsAndTypes); + TestHelpers.ingestDataToIndex(client(), indexName, TestHelpers.toHttpEntity(testIndexData)); + AnomalyDetector detector = createAnomalyDetector( + TestHelpers + .randomAnomalyDetectorUsingCategoryFields( + randomAlphaOfLength(10), + TIME_FIELD, + ImmutableList.of(indexName), + categoryFieldsAndTypes.keySet().stream().collect(Collectors.toList()) + ), + true, + client() + ); + + // Delete any existing result index + if (indexExistsWithAdminClient(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)) { + deleteIndexWithAdminClient(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS + "*"); + } + Response response = searchTopAnomalyResults( + detector.getId(), + false, + "{\"size\":3,\"category_field\":[\"keyword-field\"]," + "\"start_time_ms\":0, \"end_time_ms\":1}", + client() + ); + Map<String, Object> responseMap = entityAsMap(response); + @SuppressWarnings("unchecked") + List<Map<String, Object>> buckets = (ArrayList<Map<String, Object>>) XContentMapValues.extractValue("buckets", responseMap); + assertEquals(0, buckets.size()); + } + + public void testSearchTopAnomalyResultsOnEmptyResultIndex() throws IOException { + String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); + Map<String, String> categoryFieldsAndTypes = new HashMap<String, String>() { + { + put("keyword-field", "keyword"); + put("ip-field", "ip"); + } + }; + String testIndexData = "{\"keyword-field\": \"test-value\"}"; + TestHelpers.createIndexWithHCADFields(client(), indexName, categoryFieldsAndTypes); + TestHelpers.ingestDataToIndex(client(), indexName, TestHelpers.toHttpEntity(testIndexData)); + AnomalyDetector detector = createAnomalyDetector( + TestHelpers + .randomAnomalyDetectorUsingCategoryFields( + randomAlphaOfLength(10), + TIME_FIELD, + ImmutableList.of(indexName), + categoryFieldsAndTypes.keySet().stream().collect(Collectors.toList()) + ), + true, + client() + ); + + if (indexExistsWithAdminClient(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)) { + deleteIndexWithAdminClient(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS + "*"); + } + TestHelpers.createEmptyAnomalyResultIndex(adminClient()); + Response response = searchTopAnomalyResults( + detector.getId(), + false, +
"{\"size\":3,\"category_field\":[\"keyword-field\"]," + "\"start_time_ms\":0, \"end_time_ms\":1}", + client() + ); + Map<String, Object> responseMap = entityAsMap(response); + @SuppressWarnings("unchecked") + List<Map<String, Object>> buckets = (ArrayList<Map<String, Object>>) XContentMapValues.extractValue("buckets", responseMap); + assertEquals(0, buckets.size()); + } + + public void testSearchTopAnomalyResultsOnPopulatedResultIndex() throws IOException { + String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); + Map<String, String> categoryFieldsAndTypes = new HashMap<String, String>() { + { + put("keyword-field", "keyword"); + put("ip-field", "ip"); + } + }; + String testIndexData = "{\"keyword-field\": \"field-1\", \"ip-field\": \"1.2.3.4\", \"timestamp\": 1}"; + TestHelpers.createIndexWithHCADFields(client(), indexName, categoryFieldsAndTypes); + TestHelpers.ingestDataToIndex(client(), indexName, TestHelpers.toHttpEntity(testIndexData)); + AnomalyDetector detector = createAnomalyDetector( + TestHelpers + .randomAnomalyDetectorUsingCategoryFields( + randomAlphaOfLength(10), + TIME_FIELD, + ImmutableList.of(indexName), + categoryFieldsAndTypes.keySet().stream().collect(Collectors.toList()) + ), + true, + client() + ); + + // Ingest some sample results + if (!indexExistsWithAdminClient(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)) { + TestHelpers.createEmptyAnomalyResultIndex(adminClient()); + } + Map<String, Object> entityAttrs1 = new HashMap<String, Object>() { + { + put("keyword-field", "field-1"); + put("ip-field", "1.2.3.4"); + } + }; + Map<String, Object> entityAttrs2 = new HashMap<String, Object>() { + { + put("keyword-field", "field-2"); + put("ip-field", "5.6.7.8"); + } + }; + Map<String, Object> entityAttrs3 = new HashMap<String, Object>() { + { + put("keyword-field", "field-2"); + put("ip-field", "5.6.7.8"); + } + }; + AnomalyResult anomalyResult1 = TestHelpers + .randomHCADAnomalyDetectResult(detector.getId(), null, entityAttrs1, 0.5, 0.8, null, 5L, 5L); + AnomalyResult anomalyResult2 = TestHelpers + .randomHCADAnomalyDetectResult(detector.getId(), null, entityAttrs2, 0.5, 0.5, null, 5L, 5L); + AnomalyResult anomalyResult3 = TestHelpers + .randomHCADAnomalyDetectResult(detector.getId(), null, entityAttrs3, 0.5, 0.2, null, 5L, 5L); + + TestHelpers.ingestDataToIndex(adminClient(), ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, TestHelpers.toHttpEntity(anomalyResult1)); + TestHelpers.ingestDataToIndex(adminClient(), ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, TestHelpers.toHttpEntity(anomalyResult2)); + TestHelpers.ingestDataToIndex(adminClient(), ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, TestHelpers.toHttpEntity(anomalyResult3)); + + // Sorting by severity + Response severityResponse = searchTopAnomalyResults( + detector.getId(), + false, + "{\"category_field\":[\"keyword-field\"]," + "\"start_time_ms\":0, \"end_time_ms\":10, \"order\":\"severity\"}", + client() + ); + Map<String, Object> severityResponseMap = entityAsMap(severityResponse); + @SuppressWarnings("unchecked") + List<Map<String, Object>> severityBuckets = (ArrayList<Map<String, Object>>) XContentMapValues + .extractValue("buckets", severityResponseMap); + assertEquals(2, severityBuckets.size()); + @SuppressWarnings("unchecked") + Map<String, Object> severityBucketKey1 = (Map<String, Object>) severityBuckets.get(0).get("key"); + @SuppressWarnings("unchecked") + Map<String, Object> severityBucketKey2 = (Map<String, Object>) severityBuckets.get(1).get("key"); + assertEquals("field-1", severityBucketKey1.get("keyword-field")); + assertEquals("field-2", severityBucketKey2.get("keyword-field")); + + // Sorting by occurrence + Response occurrenceResponse = searchTopAnomalyResults( + detector.getId(), + false, + "{\"category_field\":[\"keyword-field\"]," + "\"start_time_ms\":0, \"end_time_ms\":10,
\"order\":\"occurrence\"}", + client() + ); + Map<String, Object> occurrenceResponseMap = entityAsMap(occurrenceResponse); + @SuppressWarnings("unchecked") + List<Map<String, Object>> occurrenceBuckets = (ArrayList<Map<String, Object>>) XContentMapValues + .extractValue("buckets", occurrenceResponseMap); + assertEquals(2, occurrenceBuckets.size()); + @SuppressWarnings("unchecked") + Map<String, Object> occurrenceBucketKey1 = (Map<String, Object>) occurrenceBuckets.get(0).get("key"); + @SuppressWarnings("unchecked") + Map<String, Object> occurrenceBucketKey2 = (Map<String, Object>) occurrenceBuckets.get(1).get("key"); + assertEquals("field-2", occurrenceBucketKey1.get("keyword-field")); + assertEquals("field-1", occurrenceBucketKey2.get("keyword-field")); + + // Sorting using all category fields + Response allFieldsResponse = searchTopAnomalyResults( + detector.getId(), + false, + "{\"category_field\":[\"keyword-field\", \"ip-field\"]," + "\"start_time_ms\":0, \"end_time_ms\":10, \"order\":\"severity\"}", + client() + ); + Map<String, Object> allFieldsResponseMap = entityAsMap(allFieldsResponse); + @SuppressWarnings("unchecked") + List<Map<String, Object>> allFieldsBuckets = (ArrayList<Map<String, Object>>) XContentMapValues + .extractValue("buckets", allFieldsResponseMap); + assertEquals(2, allFieldsBuckets.size()); + @SuppressWarnings("unchecked") + Map<String, Object> allFieldsBucketKey1 = (Map<String, Object>) allFieldsBuckets.get(0).get("key"); + @SuppressWarnings("unchecked") + Map<String, Object> allFieldsBucketKey2 = (Map<String, Object>) allFieldsBuckets.get(1).get("key"); + assertEquals("field-1", allFieldsBucketKey1.get("keyword-field")); + assertEquals("1.2.3.4", allFieldsBucketKey1.get("ip-field")); + assertEquals("field-2", allFieldsBucketKey2.get("keyword-field")); + assertEquals("5.6.7.8", allFieldsBucketKey2.get("ip-field")); + } + + public void testSearchTopAnomalyResultsWithCustomResultIndex() throws IOException { + String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); + String customResultIndexName = ADCommonName.CUSTOM_RESULT_INDEX_PREFIX + randomAlphaOfLength(5).toLowerCase(Locale.ROOT); + Map<String, String> categoryFieldsAndTypes = new HashMap<String, String>() { + { + put("keyword-field", "keyword"); + put("ip-field", "ip"); + } + }; + String testIndexData = "{\"keyword-field\": \"field-1\", \"ip-field\": \"1.2.3.4\", \"timestamp\": 1}"; + TestHelpers.createIndexWithHCADFields(client(), indexName, categoryFieldsAndTypes); + TestHelpers.ingestDataToIndex(client(), indexName, TestHelpers.toHttpEntity(testIndexData)); + AnomalyDetector detector = createAnomalyDetector( + TestHelpers + .randomAnomalyDetectorUsingCategoryFields( + randomAlphaOfLength(10), + TIME_FIELD, + ImmutableList.of(indexName), + categoryFieldsAndTypes.keySet().stream().collect(Collectors.toList()), + customResultIndexName + ), + true, + client() + ); + + Map<String, Object> entityAttrs = new HashMap<String, Object>() { + { + put("keyword-field", "field-1"); + put("ip-field", "1.2.3.4"); + } + }; + AnomalyResult anomalyResult = TestHelpers + .randomHCADAnomalyDetectResult(detector.getId(), null, entityAttrs, 0.5, 0.8, null, 5L, 5L); + TestHelpers.ingestDataToIndex(client(), customResultIndexName, TestHelpers.toHttpEntity(anomalyResult)); + + Response response = searchTopAnomalyResults(detector.getId(), false, "{\"start_time_ms\":0, \"end_time_ms\":10}", client()); + Map<String, Object> responseMap = entityAsMap(response); + @SuppressWarnings("unchecked") + List<Map<String, Object>> buckets = (ArrayList<Map<String, Object>>) XContentMapValues.extractValue("buckets", responseMap); + assertEquals(1, buckets.size()); + @SuppressWarnings("unchecked") + Map<String, Object> bucketKey1 = (Map<String, Object>) buckets.get(0).get("key"); + assertEquals("field-1", bucketKey1.get("keyword-field")); + assertEquals("1.2.3.4", bucketKey1.get("ip-field")); + } +} diff --git
a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java index fac7f9dc4..7d0be2ae9 100644 --- a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java +++ b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java @@ -39,8 +39,8 @@ import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.rest.RestStatus; import org.opensearch.timeseries.TestHelpers; import com.google.common.collect.ImmutableList; diff --git a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java-e b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java-e new file mode 100644 index 000000000..6ccde11d7 --- /dev/null +++ b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java-e @@ -0,0 +1,324 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS; +import static org.opensearch.timeseries.TestHelpers.AD_BASE_STATS_URI; +import static org.opensearch.timeseries.TestHelpers.HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS; +import static org.opensearch.timeseries.stats.StatNames.AD_TOTAL_BATCH_TASK_EXECUTION_COUNT; +import static org.opensearch.timeseries.stats.StatNames.MULTI_ENTITY_DETECTOR_COUNT; +import static org.opensearch.timeseries.stats.StatNames.SINGLE_ENTITY_DETECTOR_COUNT; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; + +import org.apache.hc.core5.http.ParseException; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.Before; +import org.junit.Ignore; +import org.opensearch.ad.HistoricalAnalysisRestTestCase; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskProfile; +import org.opensearch.ad.model.ADTaskState; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.timeseries.TestHelpers; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +public class HistoricalAnalysisRestApiIT extends HistoricalAnalysisRestTestCase { + + @Before + @Override + public void setUp() throws Exception { + super.categoryFieldDocCount = 3; + super.setUp(); + updateClusterSettings(MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS.getKey(), 2); + updateClusterSettings(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 5); + 
updateClusterSettings(MAX_BATCH_TASK_PER_NODE.getKey(), 10); + } + + public void testHistoricalAnalysisForSingleEntityDetector() throws Exception { + List startHistoricalAnalysisResult = startHistoricalAnalysis(0); + String detectorId = startHistoricalAnalysisResult.get(0); + String taskId = startHistoricalAnalysisResult.get(1); + checkIfTaskCanFinishCorrectly(detectorId, taskId, HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS); + } + + public void testHistoricalAnalysisForSingleEntityDetectorWithCustomResultIndex() throws Exception { + String resultIndex = ADCommonName.CUSTOM_RESULT_INDEX_PREFIX + randomAlphaOfLength(5).toLowerCase(Locale.ROOT); + List startHistoricalAnalysisResult = startHistoricalAnalysis(0, resultIndex); + String detectorId = startHistoricalAnalysisResult.get(0); + String taskId = startHistoricalAnalysisResult.get(1); + checkIfTaskCanFinishCorrectly(detectorId, taskId, HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS); + Response searchResponse = searchTaskResult(resultIndex, taskId); + assertEquals("Search anomaly result failed", RestStatus.OK, TestHelpers.restStatus(searchResponse)); + } + + public void testHistoricalAnalysisForSingleCategoryHC() throws Exception { + List startHistoricalAnalysisResult = startHistoricalAnalysis(1); + String detectorId = startHistoricalAnalysisResult.get(0); + String taskId = startHistoricalAnalysisResult.get(1); + checkIfTaskCanFinishCorrectly(detectorId, taskId, HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS); + } + + public void testHistoricalAnalysisForMultiCategoryHC() throws Exception { + List startHistoricalAnalysisResult = startHistoricalAnalysis(2); + String detectorId = startHistoricalAnalysisResult.get(0); + String taskId = startHistoricalAnalysisResult.get(1); + checkIfTaskCanFinishCorrectly(detectorId, taskId, HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS); + } + + private void checkIfTaskCanFinishCorrectly(String detectorId, String taskId, Set states) throws InterruptedException { + List results = waitUntilTaskDone(detectorId); + ADTaskProfile endTaskProfile = (ADTaskProfile) results.get(0); + Integer retryCount = (Integer) results.get(1); + ADTask stoppedAdTask = endTaskProfile.getAdTask(); + assertEquals(taskId, stoppedAdTask.getTaskId()); + if (retryCount < MAX_RETRY_TIMES) { + // It's possible that historical analysis still running after max retry times + assertTrue(states.contains(stoppedAdTask.getState())); + } + } + + @SuppressWarnings("unchecked") + private List startHistoricalAnalysis(int categoryFieldSize) throws Exception { + return startHistoricalAnalysis(categoryFieldSize, null); + } + + @SuppressWarnings("unchecked") + private List startHistoricalAnalysis(int categoryFieldSize, String resultIndex) throws Exception { + AnomalyDetector detector = createAnomalyDetector(categoryFieldSize, resultIndex); + String detectorId = detector.getId(); + + // start historical detector + String taskId = startHistoricalAnalysis(detectorId); + + // get task profile + ADTaskProfile adTaskProfile = waitUntilGetTaskProfile(detectorId); + if (categoryFieldSize > 0) { + if (!ADTaskState.RUNNING.name().equals(adTaskProfile.getAdTask().getState())) { + adTaskProfile = (ADTaskProfile) waitUntilTaskReachState(detectorId, ImmutableSet.of(ADTaskState.RUNNING.name())).get(0); + } + assertEquals((int) Math.pow(categoryFieldDocCount, categoryFieldSize), adTaskProfile.getTotalEntitiesCount().intValue()); + assertTrue(adTaskProfile.getPendingEntitiesCount() > 0); + assertTrue(adTaskProfile.getRunningEntitiesCount() > 0); + } + ADTask adTask = 
adTaskProfile.getAdTask(); + assertEquals(taskId, adTask.getTaskId()); + assertTrue(TestHelpers.HISTORICAL_ANALYSIS_RUNNING_STATS.contains(adTask.getState())); + + // get task stats + Response statsResponse = TestHelpers.makeRequest(client(), "GET", AD_BASE_STATS_URI, ImmutableMap.of(), "", null); + String statsResult = EntityUtils.toString(statsResponse.getEntity()); + Map stringObjectMap = TestHelpers.parseStatsResult(statsResult); + String detectorCountState = categoryFieldSize > 0 ? MULTI_ENTITY_DETECTOR_COUNT.getName() : SINGLE_ENTITY_DETECTOR_COUNT.getName(); + assertTrue((long) stringObjectMap.get(detectorCountState) > 0); + Map nodes = (Map) stringObjectMap.get("nodes"); + long totalBatchTaskExecution = 0; + for (String key : nodes.keySet()) { + Map nodeStats = (Map) nodes.get(key); + totalBatchTaskExecution += (long) nodeStats.get(AD_TOTAL_BATCH_TASK_EXECUTION_COUNT.getName()); + } + assertTrue(totalBatchTaskExecution > 0); + + // get detector with AD task + ToXContentObject[] result = getHistoricalAnomalyDetector(detectorId, true, client()); + AnomalyDetector parsedDetector = (AnomalyDetector) result[0]; + AnomalyDetectorJob parsedJob = (AnomalyDetectorJob) result[1]; + ADTask parsedADTask = (ADTask) result[2]; + assertNull(parsedJob); + assertNotNull(parsedDetector); + assertNotNull(parsedADTask); + assertEquals(taskId, parsedADTask.getTaskId()); + + return ImmutableList.of(detectorId, taskId); + } + + @SuppressWarnings("unchecked") + public void testStopHistoricalAnalysis() throws Exception { + // create historical detector + AnomalyDetector detector = createAnomalyDetector(); + String detectorId = detector.getId(); + + // start historical detector + String taskId = startHistoricalAnalysis(detectorId); + + waitUntilGetTaskProfile(detectorId); + + // stop historical detector + Response stopDetectorResponse = stopAnomalyDetector(detectorId, client(), false); + assertEquals(RestStatus.OK, TestHelpers.restStatus(stopDetectorResponse)); + + // get task profile + checkIfTaskCanFinishCorrectly(detectorId, taskId, ImmutableSet.of(ADTaskState.STOPPED.name())); + updateClusterSettings(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 1); + + waitUntilTaskDone(detectorId); + + // get AD stats + Response statsResponse = TestHelpers.makeRequest(client(), "GET", AD_BASE_STATS_URI, ImmutableMap.of(), "", null); + String statsResult = EntityUtils.toString(statsResponse.getEntity()); + Map stringObjectMap = TestHelpers.parseStatsResult(statsResult); + assertTrue((long) stringObjectMap.get("single_entity_detector_count") > 0); + Map nodes = (Map) stringObjectMap.get("nodes"); + long cancelledTaskCount = 0; + for (String key : nodes.keySet()) { + Map nodeStats = (Map) nodes.get(key); + cancelledTaskCount += (long) nodeStats.get("ad_canceled_batch_task_count"); + } + assertTrue(cancelledTaskCount >= 1); + } + + public void testUpdateHistoricalAnalysis() throws IOException, IllegalAccessException { + // create historical detector + AnomalyDetector detector = createAnomalyDetector(); + String detectorId = detector.getId(); + + // update historical detector + AnomalyDetector newDetector = randomAnomalyDetector(detector); + Response updateResponse = TestHelpers + .makeRequest( + client(), + "PUT", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detectorId + "?refresh=true", + ImmutableMap.of(), + TestHelpers.toHttpEntity(newDetector), + null + ); + Map responseBody = entityAsMap(updateResponse); + assertEquals(detector.getId(), responseBody.get("_id")); + assertEquals((detector.getVersion().intValue() + 1), 
(int) responseBody.get("_version")); + + // get historical detector + AnomalyDetector updatedDetector = getAnomalyDetector(detector.getId(), client()); + assertNotEquals(updatedDetector.getLastUpdateTime(), detector.getLastUpdateTime()); + assertEquals(newDetector.getName(), updatedDetector.getName()); + assertEquals(newDetector.getDescription(), updatedDetector.getDescription()); + } + + public void testUpdateRunningHistoricalAnalysis() throws Exception { + // create historical detector + AnomalyDetector detector = createAnomalyDetector(); + String detectorId = detector.getId(); + + // start historical detector + startHistoricalAnalysis(detectorId); + + // update historical detector + AnomalyDetector newDetector = randomAnomalyDetector(detector); + TestHelpers + .assertFailWith( + ResponseException.class, + "Detector is running", + () -> TestHelpers + .makeRequest( + client(), + "PUT", + TestHelpers.AD_BASE_DETECTORS_URI + "/" + detectorId + "?refresh=true", + ImmutableMap.of(), + TestHelpers.toHttpEntity(newDetector), + null + ) + ); + + waitUntilTaskDone(detectorId); + } + + // TODO: fix delete + public void testDeleteHistoricalAnalysis() throws IOException, IllegalAccessException { + // create historical detector + AnomalyDetector detector = createAnomalyDetector(); + String detectorId = detector.getId(); + + // delete detector + Response response = TestHelpers + .makeRequest(client(), "DELETE", TestHelpers.AD_BASE_DETECTORS_URI + "/" + detectorId, ImmutableMap.of(), "", null); + assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); + } + + // TODO: fix flaky test + @Ignore + public void testDeleteRunningHistoricalDetector() throws Exception { + // create historical detector + AnomalyDetector detector = createAnomalyDetector(); + String detectorId = detector.getId(); + + // start historical detector + startHistoricalAnalysis(detectorId); + + // delete detector + TestHelpers + .assertFailWith( + ResponseException.class, + "Detector is running", + () -> TestHelpers + .makeRequest(client(), "DELETE", TestHelpers.AD_BASE_DETECTORS_URI + "/" + detectorId, ImmutableMap.of(), "", null) + ); + + waitUntilTaskDone(detectorId); + } + + public void testSearchTasks() throws IOException, InterruptedException, IllegalAccessException, ParseException { + // create historical detector + AnomalyDetector detector = createAnomalyDetector(); + String detectorId = detector.getId(); + + // start historical detector + String taskId = startHistoricalAnalysis(detectorId); + + waitUntilTaskDone(detectorId); + + String query = String.format(Locale.ROOT, "{\"query\":{\"term\":{\"detector_id\":{\"value\":\"%s\"}}}}", detectorId); + Response response = TestHelpers + .makeRequest(client(), "POST", TestHelpers.AD_BASE_DETECTORS_URI + "/tasks/_search", ImmutableMap.of(), query, null); + String searchResult = EntityUtils.toString(response.getEntity()); + assertTrue(searchResult.contains(taskId)); + assertTrue(searchResult.contains(detector.getId())); + } + + private AnomalyDetector randomAnomalyDetector(AnomalyDetector detector) { + return new AnomalyDetector( + detector.getId(), + null, + randomAlphaOfLength(5), + randomAlphaOfLength(5), + detector.getTimeField(), + detector.getIndices(), + detector.getFeatureAttributes(), + detector.getFilterQuery(), + detector.getInterval(), + detector.getWindowDelay(), + detector.getShingleSize(), + detector.getUiMetadata(), + detector.getSchemaVersion(), + detector.getLastUpdateTime(), + detector.getCategoryFields(), + detector.getUser(), + detector.getCustomResultIndex(), 
+ detector.getImputationOption() + ); + } + +} diff --git a/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java b/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java index e8ef7149c..1c0758ebf 100644 --- a/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java +++ b/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java @@ -35,7 +35,7 @@ import org.opensearch.client.RestClient; import org.opensearch.commons.authuser.User; import org.opensearch.commons.rest.SecureRestClientBuilder; -import org.opensearch.rest.RestStatus; +import org.opensearch.core.rest.RestStatus; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.model.DateRange; diff --git a/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java-e b/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java-e new file mode 100644 index 000000000..1c0758ebf --- /dev/null +++ b/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java-e @@ -0,0 +1,505 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Map; +import java.util.Random; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.HttpHost; +import org.apache.hc.core5.http.message.BasicHeader; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.opensearch.ad.AnomalyDetectorRestTestCase; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorExecutionInput; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.commons.authuser.User; +import org.opensearch.commons.rest.SecureRestClientBuilder; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.DateRange; + +import com.google.common.collect.ImmutableList; + +public class SecureADRestIT extends AnomalyDetectorRestTestCase { + String aliceUser = "alice"; + RestClient aliceClient; + String bobUser = "bob"; + RestClient bobClient; + String catUser = "cat"; + RestClient catClient; + String dogUser = "dog"; + RestClient dogClient; + String elkUser = "elk"; + RestClient elkClient; + String fishUser = "fish"; + RestClient fishClient; + String goatUser = "goat"; + RestClient goatClient; + String lionUser = "lion"; + RestClient lionClient; + private String indexAllAccessRole = "index_all_access"; + private String indexSearchAccessRole = "index_all_search"; + + /** + * Create an unguessable password. Simple password are weak due to https://tinyurl.com/383em9zk + * @return a random password. 
+ */ + public static String generatePassword() { + String characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + + Random rng = new Random(); + + char[] password = new char[10]; + for (int i = 0; i < 10; i++) { + password[i] = characters.charAt(rng.nextInt(characters.length())); + } + + return new String(password); + } + + @Before + public void setupSecureTests() throws IOException { + if (!isHttps()) + throw new IllegalArgumentException("Secure Tests are running but HTTPS is not set"); + createIndexRole(indexAllAccessRole, "*"); + createSearchRole(indexSearchAccessRole, "*"); + String alicePassword = generatePassword(); + createUser(aliceUser, alicePassword, new ArrayList<>(Arrays.asList("odfe"))); + aliceClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), aliceUser, alicePassword) + .setSocketTimeout(60000) + .build(); + + String bobPassword = generatePassword(); + createUser(bobUser, bobPassword, new ArrayList<>(Arrays.asList("odfe"))); + bobClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), bobUser, bobPassword) + .setSocketTimeout(60000) + .build(); + + String catPassword = generatePassword(); + createUser(catUser, catPassword, new ArrayList<>(Arrays.asList("aes"))); + catClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), catUser, catPassword) + .setSocketTimeout(60000) + .build(); + + String dogPassword = generatePassword(); + createUser(dogUser, dogPassword, new ArrayList<>(Arrays.asList())); + dogClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), dogUser, dogPassword) + .setSocketTimeout(60000) + .build(); + + String elkPassword = generatePassword(); + createUser(elkUser, elkPassword, new ArrayList<>(Arrays.asList("odfe"))); + elkClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), elkUser, elkPassword) + .setSocketTimeout(60000) + .build(); + + String fishPassword = generatePassword(); + createUser(fishUser, fishPassword, new ArrayList<>(Arrays.asList("odfe", "aes"))); + fishClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), fishUser, fishPassword) + .setSocketTimeout(60000) + .build(); + + String goatPassword = generatePassword(); + createUser(goatUser, goatPassword, new ArrayList<>(Arrays.asList("opensearch"))); + goatClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), goatUser, goatPassword) + .setSocketTimeout(60000) + .build(); + + String lionPassword = generatePassword(); + createUser(lionUser, lionPassword, new ArrayList<>(Arrays.asList("opensearch"))); + lionClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), lionUser, lionPassword) + .setSocketTimeout(60000) + .build(); + + createRoleMapping("anomaly_read_access", new ArrayList<>(Arrays.asList(bobUser))); + createRoleMapping("anomaly_full_access", new ArrayList<>(Arrays.asList(aliceUser, catUser, dogUser, elkUser, fishUser, goatUser))); + createRoleMapping(indexAllAccessRole, new ArrayList<>(Arrays.asList(aliceUser, bobUser, catUser, dogUser, fishUser, lionUser))); + createRoleMapping(indexSearchAccessRole, new ArrayList<>(Arrays.asList(goatUser))); + } + + @After + public void deleteUserSetup() throws IOException { + aliceClient.close(); + bobClient.close(); + catClient.close(); + dogClient.close(); + elkClient.close(); + fishClient.close(); + goatClient.close(); + 
lionClient.close(); + deleteUser(aliceUser); + deleteUser(bobUser); + deleteUser(catUser); + deleteUser(dogUser); + deleteUser(elkUser); + deleteUser(fishUser); + deleteUser(goatUser); + deleteUser(lionUser); + } + + public void testCreateAnomalyDetectorWithWriteAccess() throws IOException { + // User Alice has AD full access, should be able to create a detector + AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient); + Assert.assertNotNull("User alice could not create detector", aliceDetector.getId()); + } + + public void testCreateAnomalyDetectorWithReadAccess() { + // User Bob has AD read access, should not be able to create a detector + Exception exception = expectThrows(IOException.class, () -> { createRandomAnomalyDetector(false, false, bobClient); }); + Assert.assertTrue(exception.getMessage().contains("no permissions for [cluster:admin/opendistro/ad/detector/write]")); + } + + public void testStartDetectorWithReadAccess() throws IOException { + // User Bob has AD read access, should not be able to modify a detector + AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient); + Assert.assertNotNull(aliceDetector.getId()); + Exception exception = expectThrows(IOException.class, () -> { startAnomalyDetector(aliceDetector.getId(), null, bobClient); }); + Assert.assertTrue(exception.getMessage().contains("no permissions for [cluster:admin/opendistro/ad/detector/jobmanagement]")); + } + + public void testStartDetectorForWriteUser() throws IOException { + // User Alice has AD full access, should be able to modify a detector + AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient); + Assert.assertNotNull(aliceDetector.getId()); + Instant now = Instant.now(); + Response response = startAnomalyDetector(aliceDetector.getId(), new DateRange(now.minus(10, ChronoUnit.DAYS), now), aliceClient); + MatcherAssert.assertThat(response.getStatusLine().toString(), CoreMatchers.containsString("200 OK")); + } + + public void testFilterByDisabled() throws IOException { + // User Alice has AD full access, should be able to create a detector + AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient); + // User Cat has AD full access, should be able to get a detector + AnomalyDetector detector = getAnomalyDetector(aliceDetector.getId(), catClient); + Assert.assertEquals(aliceDetector.getId(), detector.getId()); + } + + public void testGetApiFilterByEnabled() throws IOException { + // User Alice has AD full access, should be able to create a detector + AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient); + enableFilterBy(); + // User Cat has AD full access, but is part of different backend role so Cat should not be able to access + // Alice detector + Exception exception = expectThrows(IOException.class, () -> { getAnomalyDetector(aliceDetector.getId(), catClient); }); + Assert.assertTrue(exception.getMessage().contains("User does not have permissions to access detector: " + aliceDetector.getId())); + } + + private void confirmingClientIsAdmin() throws IOException { + Response resp = TestHelpers + .makeRequest( + client(), + "GET", + "_plugins/_security/api/account", + null, + "", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "admin")) + ); + Map responseMap = entityAsMap(resp); + ArrayList roles = (ArrayList) responseMap.get("roles"); + assertTrue(roles.contains("all_access")); + } + + public void testGetApiFilterByEnabledForAdmin() throws 
IOException { + // User Alice has AD full access, should be able to create a detector and has backend role "odfe" + AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient); + enableFilterBy(); + confirmingClientIsAdmin(); + AnomalyDetector detector = getAnomalyDetector(aliceDetector.getId(), client()); + Assert + .assertArrayEquals( + "User backend role of detector doesn't change", + new String[] { "odfe" }, + detector.getUser().getBackendRoles().toArray(new String[0]) + ); + } + + public void testUpdateApiFilterByEnabledForAdmin() throws IOException { + // User Alice has AD full access, should be able to create a detector and has backend role "odfe" + AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient); + enableFilterBy(); + + AnomalyDetector newDetector = new AnomalyDetector( + aliceDetector.getId(), + aliceDetector.getVersion(), + aliceDetector.getName(), + randomAlphaOfLength(10), + aliceDetector.getTimeField(), + aliceDetector.getIndices(), + aliceDetector.getFeatureAttributes(), + aliceDetector.getFilterQuery(), + aliceDetector.getInterval(), + aliceDetector.getWindowDelay(), + aliceDetector.getShingleSize(), + aliceDetector.getUiMetadata(), + aliceDetector.getSchemaVersion(), + Instant.now(), + aliceDetector.getCategoryFields(), + new User( + randomAlphaOfLength(5), + ImmutableList.of("odfe", randomAlphaOfLength(5)), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(randomAlphaOfLength(5)) + ), + null, + aliceDetector.getImputationOption() + ); + // User client has admin all access, and has "opensearch" backend role so client should be able to update detector + // But the detector's backend role should not be replaced as client's backend roles (all_access). + Response response = updateAnomalyDetector(aliceDetector.getId(), newDetector, client()); + Assert.assertEquals(response.getStatusLine().getStatusCode(), 200); + AnomalyDetector anomalyDetector = getAnomalyDetector(aliceDetector.getId(), aliceClient); + Assert + .assertArrayEquals( + "odfe is still the backendrole, not opensearch", + new String[] { "odfe" }, + anomalyDetector.getUser().getBackendRoles().toArray(new String[0]) + ); + } + + public void testUpdateApiFilterByEnabled() throws IOException { + // User Alice has AD full access, should be able to create a detector + AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient); + Assert + .assertArrayEquals( + "Wrong user roles", + new String[] { "odfe" }, + aliceDetector.getUser().getBackendRoles().toArray(new String[0]) + ); + AnomalyDetector newDetector = new AnomalyDetector( + aliceDetector.getId(), + aliceDetector.getVersion(), + aliceDetector.getName(), + randomAlphaOfLength(10), + aliceDetector.getTimeField(), + aliceDetector.getIndices(), + aliceDetector.getFeatureAttributes(), + aliceDetector.getFilterQuery(), + aliceDetector.getInterval(), + aliceDetector.getWindowDelay(), + aliceDetector.getShingleSize(), + aliceDetector.getUiMetadata(), + aliceDetector.getSchemaVersion(), + Instant.now(), + aliceDetector.getCategoryFields(), + new User( + randomAlphaOfLength(5), + ImmutableList.of("odfe", randomAlphaOfLength(5)), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(randomAlphaOfLength(5)) + ), + null, + aliceDetector.getImputationOption() + ); + enableFilterBy(); + // User Fish has AD full access, and has "odfe" backend role which is one of Alice's backend role, so + // Fish should be able to update detectors created by Alice. 
But the detector's backend role should + // not be replaced as Fish's backend roles. + Response response = updateAnomalyDetector(aliceDetector.getId(), newDetector, fishClient); + Assert.assertEquals(response.getStatusLine().getStatusCode(), 200); + AnomalyDetector anomalyDetector = getAnomalyDetector(aliceDetector.getId(), aliceClient); + Assert + .assertArrayEquals( + "Wrong user roles", + new String[] { "odfe" }, + anomalyDetector.getUser().getBackendRoles().toArray(new String[0]) + ); + } + + public void testStartApiFilterByEnabled() throws IOException { + // User Alice has AD full access, should be able to create a detector + AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient); + enableFilterBy(); + // User Cat has AD full access, but is part of different backend role so Cat should not be able to access + // Alice detector + Instant now = Instant.now(); + Exception exception = expectThrows( + IOException.class, + () -> { startAnomalyDetector(aliceDetector.getId(), new DateRange(now.minus(10, ChronoUnit.DAYS), now), catClient); } + ); + Assert.assertTrue(exception.getMessage().contains("User does not have permissions to access detector: " + aliceDetector.getId())); + } + + public void testStopApiFilterByEnabled() throws IOException { + // User Alice has AD full access, should be able to create a detector + AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient); + enableFilterBy(); + // User Cat has AD full access, but is part of different backend role so Cat should not be able to access + // Alice detector + Exception exception = expectThrows(IOException.class, () -> { stopAnomalyDetector(aliceDetector.getId(), catClient, true); }); + Assert.assertTrue(exception.getMessage().contains("User does not have permissions to access detector: " + aliceDetector.getId())); + } + + public void testDeleteApiFilterByEnabled() throws IOException { + // User Alice has AD full access, should be able to create a detector + AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient); + enableFilterBy(); + // User Cat has AD full access, but is part of different backend role so Cat should not be able to access + // Alice detector + Exception exception = expectThrows(IOException.class, () -> { deleteAnomalyDetector(aliceDetector.getId(), catClient); }); + Assert.assertTrue(exception.getMessage().contains("User does not have permissions to access detector: " + aliceDetector.getId())); + } + + public void testCreateAnomalyDetectorWithNoBackendRole() throws IOException { + enableFilterBy(); + // User Dog has AD full access, but has no backend role + // When filter by is enabled, we block creating Detectors + Exception exception = expectThrows(IOException.class, () -> { createRandomAnomalyDetector(false, false, dogClient); }); + Assert + .assertTrue( + exception.getMessage().contains("Filter by backend roles is enabled and User dog does not have backend roles configured") + ); + } + + public void testCreateAnomalyDetectorWithNoReadPermissionOfIndex() throws IOException { + enableFilterBy(); + // User alice has AD full access and index permission, so can create detector + AnomalyDetector anomalyDetector = createRandomAnomalyDetector(false, false, aliceClient); + // User elk has AD full access, but has no read permission of index + String indexName = anomalyDetector.getIndices().get(0); + Exception exception = expectThrows(IOException.class, () -> { createRandomAnomalyDetector(false, false, indexName, elkClient); }); + 
Assert.assertTrue(exception.getMessage().contains("no permissions for [indices:data/read/search]")); + } + + public void testCreateAnomalyDetectorWithCustomResultIndex() throws IOException { + // User alice has AD full access and index permission, so can create detector + AnomalyDetector anomalyDetector = createRandomAnomalyDetector(false, false, aliceClient); + // User elk has AD full access, but has no read permission of index + + String resultIndex = ADCommonName.CUSTOM_RESULT_INDEX_PREFIX + "test"; + AnomalyDetector detector = cloneDetector(anomalyDetector, resultIndex); + // User goat has no permission to create index + Exception exception = expectThrows(IOException.class, () -> { createAnomalyDetector(detector, true, goatClient); }); + Assert.assertTrue(exception.getMessage().contains("no permissions for [indices:admin/create]")); + + // User cat has permission to create index + resultIndex = ADCommonName.CUSTOM_RESULT_INDEX_PREFIX + "test2"; + TestHelpers.createIndexWithTimeField(client(), anomalyDetector.getIndices().get(0), anomalyDetector.getTimeField()); + AnomalyDetector detectorOfCat = createAnomalyDetector(cloneDetector(anomalyDetector, resultIndex), true, catClient); + assertEquals(resultIndex, detectorOfCat.getCustomResultIndex()); + } + + public void testPreviewAnomalyDetectorWithWriteAccess() throws IOException { + // User Alice has AD full access, should be able to create/preview a detector + AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient); + AnomalyDetectorExecutionInput input = new AnomalyDetectorExecutionInput( + aliceDetector.getId(), + Instant.now().minusSeconds(60 * 10), + Instant.now(), + null + ); + Response response = previewAnomalyDetector(aliceDetector.getId(), aliceClient, input); + Assert.assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); + } + + public void testPreviewAnomalyDetectorWithReadAccess() throws IOException { + // User Alice has AD full access, should be able to create a detector + AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient); + AnomalyDetectorExecutionInput input = new AnomalyDetectorExecutionInput( + randomAlphaOfLength(5), + Instant.now().minusSeconds(60 * 10), + Instant.now(), + null + ); + // User bob has AD read access, should not be able to preview a detector + Exception exception = expectThrows(IOException.class, () -> { previewAnomalyDetector(aliceDetector.getId(), bobClient, input); }); + Assert.assertTrue(exception.getMessage().contains("no permissions for [cluster:admin/opendistro/ad/detector/preview]")); + } + + public void testPreviewAnomalyDetectorWithFilterEnabled() throws IOException { + // User Alice has AD full access, should be able to create a detector + AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient); + AnomalyDetectorExecutionInput input = new AnomalyDetectorExecutionInput( + aliceDetector.getId(), + Instant.now().minusSeconds(60 * 10), + Instant.now(), + null + ); + enableFilterBy(); + // User Cat has AD full access, but is part of different backend role so Cat should not be able to access + // Alice detector + Exception exception = expectThrows(IOException.class, () -> { previewAnomalyDetector(aliceDetector.getId(), catClient, input); }); + Assert.assertTrue(exception.getMessage().contains("User does not have permissions to access detector: " + aliceDetector.getId())); + } + + public void testPreviewAnomalyDetectorWithNoReadPermissionOfIndex() throws IOException { + // User Alice has AD full 
access, should be able to create a detector + AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient); + AnomalyDetectorExecutionInput input = new AnomalyDetectorExecutionInput( + aliceDetector.getId(), + Instant.now().minusSeconds(60 * 10), + Instant.now(), + aliceDetector + ); + enableFilterBy(); + // User elk has no read permission of index + Exception exception = expectThrows(Exception.class, () -> { previewAnomalyDetector(aliceDetector.getId(), elkClient, input); }); + Assert + .assertTrue( + "actual msg: " + exception.getMessage(), + exception.getMessage().contains("no permissions for [indices:data/read/search]") + ); + } + + public void testValidateAnomalyDetectorWithWriteAccess() throws IOException { + // User Alice has AD full access, should be able to validate a detector + AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient); + Response validateResponse = validateAnomalyDetector(aliceDetector, aliceClient); + Assert.assertNotNull("User alice validated detector successfully", validateResponse); + } + + public void testValidateAnomalyDetectorWithNoADAccess() throws IOException { + // User Lion has no AD access at all, should not be able to validate a detector + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(null, Instant.now()); + Exception exception = expectThrows(IOException.class, () -> { validateAnomalyDetector(detector, lionClient); }); + Assert.assertTrue(exception.getMessage().contains("no permissions for [cluster:admin/opendistro/ad/detector/validate]")); + + } + + public void testValidateAnomalyDetectorWithReadAccess() throws IOException { + // User Bob has AD read access, should still be able to validate a detector + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(null, Instant.now()); + Response validateResponse = validateAnomalyDetector(detector, bobClient); + Assert.assertNotNull("User bob validated detector successfully", validateResponse); + } + + public void testValidateAnomalyDetectorWithNoReadPermissionOfIndex() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(null, Instant.now()); + enableFilterBy(); + // User elk has no read permission of index, can't validate detector + Exception exception = expectThrows(Exception.class, () -> { validateAnomalyDetector(detector, elkClient); }); + Assert.assertTrue(exception.getMessage().contains("no permissions for [indices:data/read/search]")); + } + + public void testValidateAnomalyDetectorWithNoBackendRole() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(null, Instant.now()); + enableFilterBy(); + // User Dog has AD full access, but has no backend role + // When filter by is enabled, we block validating Detectors + Exception exception = expectThrows(IOException.class, () -> { validateAnomalyDetector(detector, dogClient); }); + Assert + .assertTrue( + exception.getMessage().contains("Filter by backend roles is enabled and User dog does not have backend roles configured") + ); + } +} diff --git a/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java-e b/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java-e new file mode 100644 index 000000000..59eba777c --- /dev/null +++ b/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java-e @@ -0,0 +1,368 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made 
to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.rest.handler; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.action.DocWriteResponse.Result.CREATED; +import static org.opensearch.ad.constant.ADCommonMessages.CAN_NOT_FIND_LATEST_TASK; + +import java.io.IOException; +import java.util.Arrays; + +import org.junit.Before; +import org.junit.BeforeClass; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.ad.ExecuteADResultResponseRecorder; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.mock.model.MockSimpleLog; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.AnomalyDetectorJobResponse; +import org.opensearch.ad.transport.AnomalyResultAction; +import org.opensearch.ad.transport.AnomalyResultResponse; +import org.opensearch.ad.transport.ProfileAction; +import org.opensearch.ad.transport.ProfileResponse; +import org.opensearch.ad.transport.handler.AnomalyIndexHandler; +import org.opensearch.client.Client; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableList; + +public class IndexAnomalyDetectorJobActionHandlerTests extends OpenSearchTestCase { + + private static ADIndexManagement anomalyDetectionIndices; + private static String detectorId; + private static Long seqNo; + private static Long primaryTerm; + + private static NamedXContentRegistry xContentRegistry; + private static TransportService transportService; + private static TimeValue requestTimeout; + private static DiscoveryNodeFilterer nodeFilter; + private static AnomalyDetector detector; + + private ADTaskManager adTaskManager; + + private ThreadPool threadPool; + + private ExecuteADResultResponseRecorder recorder; + private Client client; + private IndexAnomalyDetectorJobActionHandler handler; + private AnomalyIndexHandler anomalyResultHandler; + private 
NodeStateManager nodeStateManager; + private ADTaskCacheManager adTaskCacheManager; + + @BeforeClass + public static void setOnce() throws IOException { + detectorId = "123"; + seqNo = 1L; + primaryTerm = 2L; + anomalyDetectionIndices = mock(ADIndexManagement.class); + xContentRegistry = NamedXContentRegistry.EMPTY; + transportService = mock(TransportService.class); + + requestTimeout = TimeValue.timeValueMinutes(60); + when(anomalyDetectionIndices.doesJobIndexExist()).thenReturn(true); + + nodeFilter = mock(DiscoveryNodeFilterer.class); + detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList("a")); + } + + @SuppressWarnings("unchecked") + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + client = mock(Client.class); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(false); + listener.onResponse(response); + + return null; + }).when(client).get(any(GetRequest.class), any()); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + + IndexResponse response = mock(IndexResponse.class); + when(response.getResult()).thenReturn(CREATED); + listener.onResponse(response); + + return null; + }).when(client).index(any(IndexRequest.class), any()); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[2]; + + AnomalyResultResponse response = new AnomalyResultResponse(null, "", 0L, 10L, true); + listener.onResponse(response); + + return null; + }).when(client).execute(any(AnomalyResultAction.class), any(), any()); + + adTaskManager = mock(ADTaskManager.class); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[4]; + + AnomalyDetectorJobResponse response = mock(AnomalyDetectorJobResponse.class); + listener.onResponse(response); + + return null; + }).when(adTaskManager).startDetector(any(), any(), any(), any(), any()); + + threadPool = mock(ThreadPool.class); + + anomalyResultHandler = mock(AnomalyIndexHandler.class); + + nodeStateManager = mock(NodeStateManager.class); + + adTaskCacheManager = mock(ADTaskCacheManager.class); + when(adTaskCacheManager.hasQueriedResultIndex(anyString())).thenReturn(true); + + recorder = new ExecuteADResultResponseRecorder( + anomalyDetectionIndices, + anomalyResultHandler, + adTaskManager, + nodeFilter, + threadPool, + client, + nodeStateManager, + adTaskCacheManager, + 32 + ); + + handler = new IndexAnomalyDetectorJobActionHandler( + client, + anomalyDetectionIndices, + detectorId, + seqNo, + primaryTerm, + requestTimeout, + xContentRegistry, + transportService, + adTaskManager, + recorder + ); + } + + @SuppressWarnings("unchecked") + public void testDelayHCProfile() { + when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(false); + + ActionListener listener = mock(ActionListener.class); + + handler.startAnomalyDetectorJob(detector, listener); + + verify(client, times(1)).get(any(), any()); + verify(client, times(1)).execute(any(), any(), any()); + verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any()); + verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString()); + verify(threadPool, times(1)).schedule(any(), any(), any()); + verify(listener, 
times(1)).onResponse(any()); + } + + @SuppressWarnings("unchecked") + public void testNoDelayHCProfile() { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[2]; + + ProfileResponse response = mock(ProfileResponse.class); + when(response.getTotalUpdates()).thenReturn(3L); + listener.onResponse(response); + + return null; + }).when(client).execute(any(ProfileAction.class), any(), any()); + + when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true); + + ActionListener listener = mock(ActionListener.class); + + handler.startAnomalyDetectorJob(detector, listener); + + verify(client, times(1)).get(any(), any()); + verify(client, times(2)).execute(any(), any(), any()); + verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any()); + verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString()); + verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); + verify(threadPool, never()).schedule(any(), any(), any()); + verify(listener, times(1)).onResponse(any()); + } + + @SuppressWarnings("unchecked") + public void testHCProfileException() { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[2]; + + listener.onFailure(new RuntimeException()); + + return null; + }).when(client).execute(any(ProfileAction.class), any(), any()); + + when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true); + + ActionListener listener = mock(ActionListener.class); + + handler.startAnomalyDetectorJob(detector, listener); + + verify(client, times(1)).get(any(), any()); + verify(client, times(2)).execute(any(), any(), any()); + verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any()); + verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString()); + verify(adTaskManager, never()).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); + verify(threadPool, never()).schedule(any(), any(), any()); + verify(listener, times(1)).onResponse(any()); + } + + @SuppressWarnings("unchecked") + public void testUpdateLatestRealtimeTaskOnCoordinatingNodeResourceNotFoundException() { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[2]; + + ProfileResponse response = mock(ProfileResponse.class); + when(response.getTotalUpdates()).thenReturn(3L); + listener.onResponse(response); + + return null; + }).when(client).execute(any(ProfileAction.class), any(), any()); + + when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[5]; + + listener.onFailure(new ResourceNotFoundException(CAN_NOT_FIND_LATEST_TASK)); + + return null; + }).when(adTaskManager).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); + + ActionListener listener = mock(ActionListener.class); + + handler.startAnomalyDetectorJob(detector, listener); + + verify(client, times(1)).get(any(), any()); + verify(client, times(2)).execute(any(), any(), any()); + verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any()); + verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString()); + verify(adTaskManager, 
times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); + verify(adTaskManager, times(1)).removeRealtimeTaskCache(anyString()); + verify(threadPool, never()).schedule(any(), any(), any()); + verify(listener, times(1)).onResponse(any()); + } + + @SuppressWarnings("unchecked") + public void testUpdateLatestRealtimeTaskOnCoordinatingException() { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[2]; + + ProfileResponse response = mock(ProfileResponse.class); + when(response.getTotalUpdates()).thenReturn(3L); + listener.onResponse(response); + + return null; + }).when(client).execute(any(ProfileAction.class), any(), any()); + + when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[5]; + + listener.onFailure(new RuntimeException()); + + return null; + }).when(adTaskManager).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); + + ActionListener listener = mock(ActionListener.class); + + handler.startAnomalyDetectorJob(detector, listener); + + verify(client, times(1)).get(any(), any()); + verify(client, times(2)).execute(any(), any(), any()); + verify(adTaskManager, times(1)).startDetector(any(), any(), any(), any(), any()); + verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString()); + verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); + verify(adTaskManager, never()).removeRealtimeTaskCache(anyString()); + verify(adTaskManager, times(1)).skipUpdateHCRealtimeTask(anyString(), anyString()); + verify(threadPool, never()).schedule(any(), any(), any()); + verify(listener, times(1)).onResponse(any()); + } + + @SuppressWarnings("unchecked") + public void testIndexException() throws IOException { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[2]; + + listener.onFailure(new InternalFailure(detectorId, ADCommonMessages.NO_MODEL_ERR_MSG)); + + return null; + }).when(client).execute(any(AnomalyResultAction.class), any(), any()); + + ActionListener listener = mock(ActionListener.class); + AggregationBuilder aggregationBuilder = TestHelpers + .parseAggregation("{\"test\":{\"max\":{\"field\":\"" + MockSimpleLog.VALUE_FIELD + "\"}}}"); + Feature feature = new Feature(randomAlphaOfLength(5), randomAlphaOfLength(10), true, aggregationBuilder); + detector = TestHelpers + .randomDetector( + ImmutableList.of(feature), + "test", + 10, + MockSimpleLog.TIME_FIELD, + null, + ADCommonName.CUSTOM_RESULT_INDEX_PREFIX + "index" + ); + when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false); + handler.startAnomalyDetectorJob(detector, listener); + verify(anomalyResultHandler, times(1)).index(any(), any(), eq(null)); + verify(threadPool, times(1)).schedule(any(), any(), any()); + } +} diff --git a/src/test/java/org/opensearch/ad/settings/ADEnabledSettingTests.java-e b/src/test/java/org/opensearch/ad/settings/ADEnabledSettingTests.java-e new file mode 100644 index 000000000..6de90a068 --- /dev/null +++ b/src/test/java/org/opensearch/ad/settings/ADEnabledSettingTests.java-e @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.settings; + +import static 
org.mockito.Mockito.mock; +import static org.opensearch.common.settings.Setting.Property.Dynamic; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.mockito.Mockito; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.OpenSearchTestCase; + +public class ADEnabledSettingTests extends OpenSearchTestCase { + + public void testIsADEnabled() { + assertTrue(ADEnabledSetting.isADEnabled()); + ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.AD_ENABLED, false); + assertTrue(!ADEnabledSetting.isADEnabled()); + ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.AD_ENABLED, true); + } + + public void testIsADBreakerEnabled() { + assertTrue(ADEnabledSetting.isADBreakerEnabled()); + ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.AD_BREAKER_ENABLED, false); + assertTrue(!ADEnabledSetting.isADBreakerEnabled()); + } + + public void testIsInterpolationInColdStartEnabled() { + assertTrue(!ADEnabledSetting.isInterpolationInColdStartEnabled()); + ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, true); + assertTrue(ADEnabledSetting.isInterpolationInColdStartEnabled()); + } + + public void testIsDoorKeeperInCacheEnabled() { + assertTrue(!ADEnabledSetting.isDoorKeeperInCacheEnabled()); + ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.DOOR_KEEPER_IN_CACHE_ENABLED, true); + assertTrue(ADEnabledSetting.isDoorKeeperInCacheEnabled()); + } + + public void testSetSettingsUpdateConsumers() { + Setting testSetting = Setting.boolSetting("test.setting", true, Setting.Property.NodeScope, Dynamic); + Map<String, Setting<?>> settings = new HashMap<>(); + settings.put("test.setting", testSetting); + ADEnabledSetting dynamicNumericSetting = new ADEnabledSetting(settings); + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, Collections.singleton(testSetting)); + ClusterService clusterService = mock(ClusterService.class); + Mockito.when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + dynamicNumericSetting.init(clusterService); + + assertEquals(true, dynamicNumericSetting.getSettingValue("test.setting")); + } + + public void testGetSettings() { + Setting testSetting1 = Setting.boolSetting("test.setting1", true, Setting.Property.NodeScope); + Setting testSetting2 = Setting.boolSetting("test.setting2", false, Setting.Property.NodeScope); + Map<String, Setting<?>> settings = new HashMap<>(); + settings.put("test.setting1", testSetting1); + settings.put("test.setting2", testSetting2); + ADEnabledSetting dynamicNumericSetting = new ADEnabledSetting(settings); + List<Setting<?>> returnedSettings = dynamicNumericSetting.getSettings(); + assertEquals(2, returnedSettings.size()); + assertTrue(returnedSettings.containsAll(settings.values())); + } +} diff --git a/src/test/java/org/opensearch/ad/settings/ADNumericSettingTests.java-e b/src/test/java/org/opensearch/ad/settings/ADNumericSettingTests.java-e new file mode 100644 index 000000000..71d131641 --- /dev/null +++ b/src/test/java/org/opensearch/ad/settings/ADNumericSettingTests.java-e @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.settings; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import 
org.opensearch.common.settings.Setting; +import org.opensearch.test.OpenSearchTestCase; + +public class ADNumericSettingTests extends OpenSearchTestCase { + private ADNumericSetting adSetting; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + adSetting = ADNumericSetting.getInstance(); + } + + public void testMaxCategoricalFields() { + adSetting.setSettingValue(ADNumericSetting.CATEGORY_FIELD_LIMIT, 3); + int value = ADNumericSetting.maxCategoricalFields(); + assertEquals("Expected value is 3", 3, value); + } + + public void testGetSettingValue() { + Map<String, Setting<?>> settingsMap = new HashMap<>(); + Setting testSetting = Setting.intSetting("test.setting", 1, Setting.Property.NodeScope); + settingsMap.put("test.setting", testSetting); + adSetting = new ADNumericSetting(settingsMap); + + adSetting.setSettingValue("test.setting", 2); + Integer value = adSetting.getSettingValue("test.setting"); + assertEquals("Expected value is 2", 2, value.intValue()); + } + + public void testGetSettingNonexistentKey() { + try { + adSetting.getSettingValue("nonexistent.key"); + fail("Expected an IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertEquals("Cannot find setting by key [nonexistent.key]", e.getMessage()); + } + } +} diff --git a/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java b/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java index 9ee9a5d37..72e336ea7 100644 --- a/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java +++ b/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java @@ -15,19 +15,19 @@ import java.util.List; import org.junit.Before; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; @SuppressWarnings({ "rawtypes" }) public class AnomalyDetectorSettingsTests extends OpenSearchTestCase { - AnomalyDetectorPlugin plugin; + TimeSeriesAnalyticsPlugin plugin; @Before public void setup() { - this.plugin = new AnomalyDetectorPlugin(); + this.plugin = new TimeSeriesAnalyticsPlugin(); } public void testAllLegacyOpenDistroSettingsReturned() { diff --git a/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java-e b/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java-e new file mode 100644 index 000000000..72e336ea7 --- /dev/null +++ b/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java-e @@ -0,0 +1,410 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.settings; + +import java.util.Arrays; +import java.util.List; + +import org.junit.Before; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +@SuppressWarnings({ "rawtypes" }) +public class AnomalyDetectorSettingsTests extends OpenSearchTestCase { + TimeSeriesAnalyticsPlugin plugin; + + @Before + public void setup() { + this.plugin = new TimeSeriesAnalyticsPlugin(); + } + + public void testAllLegacyOpenDistroSettingsReturned() { + List<Setting<?>> settings = plugin.getSettings(); + assertTrue( + "legacy setting must be returned from settings", + settings + .containsAll( + Arrays + .asList( + LegacyOpenDistroAnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, + LegacyOpenDistroAnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS, + LegacyOpenDistroAnomalyDetectorSettings.MAX_ANOMALY_FEATURES, + LegacyOpenDistroAnomalyDetectorSettings.REQUEST_TIMEOUT, + LegacyOpenDistroAnomalyDetectorSettings.DETECTION_INTERVAL, + LegacyOpenDistroAnomalyDetectorSettings.DETECTION_WINDOW_DELAY, + LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, + LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS, + LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE, + LegacyOpenDistroAnomalyDetectorSettings.COOLDOWN_MINUTES, + LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_MINUTES, + LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_INITIAL_DELAY, + LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_BACKOFF, + LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD, + LegacyOpenDistroAnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, + LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, + LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW, + LegacyOpenDistroAnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT, + LegacyOpenDistroAnomalyDetectorSettings.MAX_PRIMARY_SHARDS, + LegacyOpenDistroAnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES, + LegacyOpenDistroAnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND, + LegacyOpenDistroAnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE, + LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS, + LegacyOpenDistroAnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, + LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE + ) + ) + ); + } + + public void testAllOpenSearchSettingsReturned() { + List<Setting<?>> settings = plugin.getSettings(); + assertTrue( + "opensearch setting must be returned from settings", + settings + .containsAll( + Arrays + .asList( + AnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, + AnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS, + AnomalyDetectorSettings.MAX_ANOMALY_FEATURES, + AnomalyDetectorSettings.REQUEST_TIMEOUT, + AnomalyDetectorSettings.DETECTION_INTERVAL, + AnomalyDetectorSettings.DETECTION_WINDOW_DELAY, + AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, + AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD, + AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE, + AnomalyDetectorSettings.COOLDOWN_MINUTES, + AnomalyDetectorSettings.BACKOFF_MINUTES, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF, + AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD, +
AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, + AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, + AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW, + AnomalyDetectorSettings.AD_INDEX_PRESSURE_SOFT_LIMIT, + AnomalyDetectorSettings.AD_INDEX_PRESSURE_HARD_LIMIT, + AnomalyDetectorSettings.AD_MAX_PRIMARY_SHARDS, + AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES, + AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE, + AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS, + AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, + AnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE, + AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_CONCURRENCY, + AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_BATCH_SIZE, + AnomalyDetectorSettings.DEDICATED_CACHE_SIZE, + AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, + AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, + AnomalyDetectorSettings.PAGE_SIZE + ) + ) + ); + } + + public void testAllLegacyOpenDistroSettingsFallback() { + assertEquals( + AnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.MAX_ANOMALY_FEATURES.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.MAX_ANOMALY_FEATURES.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.REQUEST_TIMEOUT.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.REQUEST_TIMEOUT.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.DETECTION_INTERVAL.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.DETECTION_INTERVAL.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.DETECTION_WINDOW_DELAY.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.DETECTION_WINDOW_DELAY.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.COOLDOWN_MINUTES.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.COOLDOWN_MINUTES.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.BACKOFF_MINUTES.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_MINUTES.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY.get(Settings.EMPTY), + 
LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_INITIAL_DELAY.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_BACKOFF.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(Settings.EMPTY) + ); + // MAX_ENTITIES_FOR_PREVIEW does not use legacy setting + assertEquals(Integer.valueOf(5), AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW.get(Settings.EMPTY)); + // INDEX_PRESSURE_SOFT_LIMIT does not use legacy setting + assertEquals(Float.valueOf(0.6f), AnomalyDetectorSettings.AD_INDEX_PRESSURE_SOFT_LIMIT.get(Settings.EMPTY)); + assertEquals( + AnomalyDetectorSettings.AD_MAX_PRIMARY_SHARDS.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.MAX_PRIMARY_SHARDS.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR.get(Settings.EMPTY) + ); + assertEquals( + AnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE.get(Settings.EMPTY), + LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE.get(Settings.EMPTY) + ); + } + + public void testSettingsGetValue() { + Settings settings = Settings.builder().put("plugins.anomaly_detection.request_timeout", "42s").build(); + assertEquals(AnomalyDetectorSettings.REQUEST_TIMEOUT.get(settings), TimeValue.timeValueSeconds(42)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.REQUEST_TIMEOUT.get(settings), TimeValue.timeValueSeconds(10)); + + settings = Settings.builder().put("plugins.anomaly_detection.max_anomaly_detectors", 99).build(); + assertEquals(AnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings), Integer.valueOf(99)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings), Integer.valueOf(1000)); + + settings = Settings.builder().put("plugins.anomaly_detection.max_multi_entity_anomaly_detectors", 98).build(); + assertEquals(AnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS.get(settings), Integer.valueOf(98)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS.get(settings), Integer.valueOf(10)); + + settings = Settings.builder().put("plugins.anomaly_detection.max_anomaly_features", 7).build(); + assertEquals(AnomalyDetectorSettings.MAX_ANOMALY_FEATURES.get(settings), Integer.valueOf(7)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_ANOMALY_FEATURES.get(settings), Integer.valueOf(5)); + + settings = Settings.builder().put("plugins.anomaly_detection.detection_interval", 
TimeValue.timeValueMinutes(96)).build(); + assertEquals(AnomalyDetectorSettings.DETECTION_INTERVAL.get(settings), TimeValue.timeValueMinutes(96)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.DETECTION_INTERVAL.get(settings), TimeValue.timeValueMinutes(10)); + + settings = Settings.builder().put("plugins.anomaly_detection.detection_window_delay", TimeValue.timeValueMinutes(95)).build(); + assertEquals(AnomalyDetectorSettings.DETECTION_WINDOW_DELAY.get(settings), TimeValue.timeValueMinutes(95)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.DETECTION_WINDOW_DELAY.get(settings), TimeValue.timeValueMinutes(0)); + + settings = Settings + .builder() + .put("plugins.anomaly_detection.ad_result_history_rollover_period", TimeValue.timeValueHours(94)) + .build(); + assertEquals(AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD.get(settings), TimeValue.timeValueHours(94)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD.get(settings), TimeValue.timeValueHours(12)); + + settings = Settings.builder().put("plugins.anomaly_detection.ad_result_history_max_docs_per_shard", 93).build(); + assertEquals(AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD.get(settings), Long.valueOf(93)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS.get(settings), Long.valueOf(250000000)); + + settings = Settings + .builder() + .put("plugins.anomaly_detection.ad_result_history_retention_period", TimeValue.timeValueDays(92)) + .build(); + assertEquals(AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD.get(settings), TimeValue.timeValueDays(92)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD.get(settings), TimeValue.timeValueDays(30)); + + settings = Settings.builder().put("plugins.anomaly_detection.max_retry_for_unresponsive_node", 91).build(); + assertEquals(AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(settings), Integer.valueOf(91)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(settings), Integer.valueOf(5)); + + settings = Settings.builder().put("plugins.anomaly_detection.cooldown_minutes", TimeValue.timeValueMinutes(90)).build(); + assertEquals(AnomalyDetectorSettings.COOLDOWN_MINUTES.get(settings), TimeValue.timeValueMinutes(90)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.COOLDOWN_MINUTES.get(settings), TimeValue.timeValueMinutes(5)); + + settings = Settings.builder().put("plugins.anomaly_detection.backoff_minutes", TimeValue.timeValueMinutes(89)).build(); + assertEquals(AnomalyDetectorSettings.BACKOFF_MINUTES.get(settings), TimeValue.timeValueMinutes(89)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_MINUTES.get(settings), TimeValue.timeValueMinutes(15)); + + settings = Settings.builder().put("plugins.anomaly_detection.backoff_initial_delay", TimeValue.timeValueMillis(88)).build(); + assertEquals(AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY.get(settings), TimeValue.timeValueMillis(88)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_INITIAL_DELAY.get(settings), TimeValue.timeValueMillis(1000)); + + settings = Settings.builder().put("plugins.anomaly_detection.max_retry_for_backoff", 87).build(); + assertEquals(AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF.get(settings), Integer.valueOf(87)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_BACKOFF.get(settings), Integer.valueOf(3)); + + settings = 
Settings.builder().put("plugins.anomaly_detection.max_retry_for_end_run_exception", 86).build(); + assertEquals(AnomalyDetectorSettings.MAX_RETRY_FOR_END_RUN_EXCEPTION.get(settings), Integer.valueOf(86)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_END_RUN_EXCEPTION.get(settings), Integer.valueOf(6)); + + settings = Settings.builder().put("plugins.anomaly_detection.filter_by_backend_roles", true).build(); + assertEquals(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings), Boolean.valueOf(true)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings), Boolean.valueOf(false)); + + settings = Settings.builder().put("plugins.anomaly_detection.model_max_size_percent", 0.3).build(); + assertEquals(AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(settings), Double.valueOf(0.3)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(settings), Double.valueOf(0.1)); + + settings = Settings.builder().put("plugins.anomaly_detection.max_entities_per_query", 83).build(); + assertEquals(AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY.get(settings), Integer.valueOf(83)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY.get(settings), Integer.valueOf(1000)); + + settings = Settings.builder().put("plugins.anomaly_detection.max_entities_for_preview", 22).build(); + assertEquals(AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW.get(settings), Integer.valueOf(22)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW.get(settings), Integer.valueOf(30)); + + settings = Settings.builder().put("plugins.anomaly_detection.index_pressure_soft_limit", 81f).build(); + assertEquals(AnomalyDetectorSettings.AD_INDEX_PRESSURE_SOFT_LIMIT.get(settings), Float.valueOf(81f)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT.get(settings), Float.valueOf(0.8f)); + + settings = Settings.builder().put("plugins.anomaly_detection.max_primary_shards", 80).build(); + assertEquals(AnomalyDetectorSettings.AD_MAX_PRIMARY_SHARDS.get(settings), Integer.valueOf(80)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_PRIMARY_SHARDS.get(settings), Integer.valueOf(10)); + + settings = Settings.builder().put("plugins.anomaly_detection.max_cache_miss_handling_per_second", 79).build(); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND.get(settings), Integer.valueOf(100)); + + settings = Settings.builder().put("plugins.anomaly_detection.max_batch_task_per_node", 78).build(); + assertEquals(AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE.get(settings), Integer.valueOf(78)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE.get(settings), Integer.valueOf(10)); + + settings = Settings.builder().put("plugins.anomaly_detection.max_old_ad_task_docs_per_detector", 77).build(); + assertEquals(AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR.get(settings), Integer.valueOf(77)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR.get(settings), Integer.valueOf(1)); + + settings = Settings.builder().put("plugins.anomaly_detection.batch_task_piece_size", 76).build(); + assertEquals(AnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE.get(settings), Integer.valueOf(76)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE.get(settings), Integer.valueOf(1000)); + + settings = 
Settings.builder().put("plugins.anomaly_detection.batch_task_piece_interval_seconds", 76).build(); + assertEquals(AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS.get(settings), Integer.valueOf(76)); + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS.get(settings), Integer.valueOf(5)); + } + + public void testSettingsGetValueWithLegacyFallback() { + Settings settings = Settings + .builder() + .put("opendistro.anomaly_detection.max_anomaly_detectors", 1) + .put("opendistro.anomaly_detection.max_multi_entity_anomaly_detectors", 2) + .put("opendistro.anomaly_detection.max_anomaly_features", 3) + .put("opendistro.anomaly_detection.request_timeout", "4s") + .put("opendistro.anomaly_detection.detection_interval", "5m") + .put("opendistro.anomaly_detection.detection_window_delay", "6m") + .put("opendistro.anomaly_detection.ad_result_history_rollover_period", "7h") + .put("opendistro.anomaly_detection.ad_result_history_max_docs", 8L) + .put("opendistro.anomaly_detection.ad_result_history_retention_period", "9d") + .put("opendistro.anomaly_detection.max_retry_for_unresponsive_node", 10) + .put("opendistro.anomaly_detection.cooldown_minutes", "11m") + .put("opendistro.anomaly_detection.backoff_minutes", "12m") + .put("opendistro.anomaly_detection.backoff_initial_delay", "13ms") // + .put("opendistro.anomaly_detection.max_retry_for_backoff", 14) + .put("opendistro.anomaly_detection.max_retry_for_end_run_exception", 15) + .put("opendistro.anomaly_detection.filter_by_backend_roles", true) + .put("opendistro.anomaly_detection.model_max_size_percent", 0.6D) + .put("opendistro.anomaly_detection.max_entities_for_preview", 19) + .put("opendistro.anomaly_detection.index_pressure_soft_limit", 20F) + .put("opendistro.anomaly_detection.max_primary_shards", 21) + .put("opendistro.anomaly_detection.max_cache_miss_handling_per_second", 22) + .put("opendistro.anomaly_detection.max_batch_task_per_node", 23) + .put("opendistro.anomaly_detection.max_old_ad_task_docs_per_detector", 24) + .put("opendistro.anomaly_detection.batch_task_piece_size", 25) + .put("opendistro.anomaly_detection.batch_task_piece_interval_seconds", 26) + .build(); + + assertEquals(AnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings), Integer.valueOf(1)); + assertEquals(AnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS.get(settings), Integer.valueOf(2)); + assertEquals(AnomalyDetectorSettings.MAX_ANOMALY_FEATURES.get(settings), Integer.valueOf(3)); + assertEquals(AnomalyDetectorSettings.REQUEST_TIMEOUT.get(settings), TimeValue.timeValueSeconds(4)); + assertEquals(AnomalyDetectorSettings.DETECTION_INTERVAL.get(settings), TimeValue.timeValueMinutes(5)); + assertEquals(AnomalyDetectorSettings.DETECTION_WINDOW_DELAY.get(settings), TimeValue.timeValueMinutes(6)); + assertEquals(AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD.get(settings), TimeValue.timeValueHours(7)); + // AD_RESULT_HISTORY_MAX_DOCS is removed in the new release + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS.get(settings), Long.valueOf(8L)); + assertEquals(AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD.get(settings), TimeValue.timeValueDays(9)); + assertEquals(AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(settings), Integer.valueOf(10)); + assertEquals(AnomalyDetectorSettings.COOLDOWN_MINUTES.get(settings), TimeValue.timeValueMinutes(11)); + assertEquals(AnomalyDetectorSettings.BACKOFF_MINUTES.get(settings), 
TimeValue.timeValueMinutes(12)); + assertEquals(AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY.get(settings), TimeValue.timeValueMillis(13)); + assertEquals(AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF.get(settings), Integer.valueOf(14)); + assertEquals(AnomalyDetectorSettings.MAX_RETRY_FOR_END_RUN_EXCEPTION.get(settings), Integer.valueOf(15)); + assertEquals(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings), Boolean.valueOf(true)); + assertEquals(AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(settings), Double.valueOf(0.6D)); + // MAX_ENTITIES_FOR_PREVIEW uses default instead of legacy fallback + assertEquals(AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW.get(settings), Integer.valueOf(5)); + // INDEX_PRESSURE_SOFT_LIMIT uses default instead of legacy fallback + assertEquals(AnomalyDetectorSettings.AD_INDEX_PRESSURE_SOFT_LIMIT.get(settings), Float.valueOf(0.6F)); + assertEquals(AnomalyDetectorSettings.AD_MAX_PRIMARY_SHARDS.get(settings), Integer.valueOf(21)); + // MAX_CACHE_MISS_HANDLING_PER_SECOND is removed in the new release + assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND.get(settings), Integer.valueOf(22)); + assertEquals(AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE.get(settings), Integer.valueOf(23)); + assertEquals(AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR.get(settings), Integer.valueOf(24)); + assertEquals(AnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE.get(settings), Integer.valueOf(25)); + assertEquals(AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS.get(settings), Integer.valueOf(26)); + + assertSettingDeprecationsAndWarnings( + new Setting[] { + LegacyOpenDistroAnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, + LegacyOpenDistroAnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS, + LegacyOpenDistroAnomalyDetectorSettings.MAX_ANOMALY_FEATURES, + LegacyOpenDistroAnomalyDetectorSettings.REQUEST_TIMEOUT, + LegacyOpenDistroAnomalyDetectorSettings.DETECTION_INTERVAL, + LegacyOpenDistroAnomalyDetectorSettings.DETECTION_WINDOW_DELAY, + LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD, + LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS, + LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE, + LegacyOpenDistroAnomalyDetectorSettings.COOLDOWN_MINUTES, + LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_MINUTES, + LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_INITIAL_DELAY, + LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_BACKOFF, + LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_END_RUN_EXCEPTION, + LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD, + LegacyOpenDistroAnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, + LegacyOpenDistroAnomalyDetectorSettings.MAX_PRIMARY_SHARDS, + LegacyOpenDistroAnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES, + LegacyOpenDistroAnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND, + LegacyOpenDistroAnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE, + LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS, + LegacyOpenDistroAnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, + LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE } + ); + } +} diff --git a/src/test/java/org/opensearch/ad/stats/ADStatTests.java-e b/src/test/java/org/opensearch/ad/stats/ADStatTests.java-e new file mode 100644 index 000000000..1912f92ad --- /dev/null +++ b/src/test/java/org/opensearch/ad/stats/ADStatTests.java-e @@ -0,0 +1,69 @@ 
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.ad.stats;
+
+import java.util.function.Supplier;
+
+import org.junit.Test;
+import org.opensearch.ad.stats.suppliers.CounterSupplier;
+import org.opensearch.ad.stats.suppliers.SettableSupplier;
+import org.opensearch.test.OpenSearchTestCase;
+
+public class ADStatTests extends OpenSearchTestCase {
+
+    @Test
+    public void testIsClusterLevel() {
+        ADStat<String> stat1 = new ADStat<>(true, new TestSupplier());
+        assertTrue("isCluster returns the wrong value", stat1.isClusterLevel());
+        ADStat<String> stat2 = new ADStat<>(false, new TestSupplier());
+        assertTrue("isCluster returns the wrong value", !stat2.isClusterLevel());
+    }
+
+    @Test
+    public void testGetValue() {
+        ADStat<Long> stat1 = new ADStat<>(false, new CounterSupplier());
+        assertEquals("GetValue returns the incorrect value", 0L, (long) (stat1.getValue()));
+
+        ADStat<String> stat2 = new ADStat<>(false, new TestSupplier());
+        assertEquals("GetValue returns the incorrect value", "test", stat2.getValue());
+    }
+
+    @Test
+    public void testSetValue() {
+        ADStat<Long> stat = new ADStat<>(false, new SettableSupplier());
+        assertEquals("GetValue returns the incorrect value", 0L, (long) (stat.getValue()));
+        stat.setValue(10L);
+        assertEquals("GetValue returns the incorrect value", 10L, (long) stat.getValue());
+    }
+
+    @Test
+    public void testIncrement() {
+        ADStat<Long> incrementStat = new ADStat<>(false, new CounterSupplier());
+
+        for (Long i = 0L; i < 100; i++) {
+            assertEquals("increment does not work", i, incrementStat.getValue());
+            incrementStat.increment();
+        }
+
+        // Ensure that no problems occur for a stat that cannot be incremented
+        ADStat<String> nonIncStat = new ADStat<>(false, new TestSupplier());
+        nonIncStat.increment();
+    }
+
+    private class TestSupplier implements Supplier<String> {
+        TestSupplier() {}
+
+        public String get() {
+            return "test";
+        }
+    }
+}
diff --git a/src/test/java/org/opensearch/ad/stats/ADStatsResponseTests.java-e b/src/test/java/org/opensearch/ad/stats/ADStatsResponseTests.java-e
new file mode 100644
index 000000000..194623bd5
--- /dev/null
+++ b/src/test/java/org/opensearch/ad/stats/ADStatsResponseTests.java-e
@@ -0,0 +1,123 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */ + +package org.opensearch.ad.stats; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Test; +import org.opensearch.action.FailedNodeException; +import org.opensearch.ad.transport.ADStatsNodeResponse; +import org.opensearch.ad.transport.ADStatsNodesResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.test.OpenSearchTestCase; + +public class ADStatsResponseTests extends OpenSearchTestCase { + @Test + public void testGetAndSetClusterStats() { + ADStatsResponse adStatsResponse = new ADStatsResponse(); + Map testClusterStats = new HashMap<>(); + testClusterStats.put("test_stat", 1L); + adStatsResponse.setClusterStats(testClusterStats); + assertEquals(testClusterStats, adStatsResponse.getClusterStats()); + } + + @Test + public void testGetAndSetADStatsNodesResponse() { + ADStatsResponse adStatsResponse = new ADStatsResponse(); + List responses = Collections.emptyList(); + List failures = Collections.emptyList(); + ADStatsNodesResponse adStatsNodesResponse = new ADStatsNodesResponse(ClusterName.DEFAULT, responses, failures); + adStatsResponse.setADStatsNodesResponse(adStatsNodesResponse); + assertEquals(adStatsNodesResponse, adStatsResponse.getADStatsNodesResponse()); + } + + @Test + public void testMerge() { + ADStatsResponse adStatsResponse1 = new ADStatsResponse(); + Map testClusterStats = new HashMap<>(); + testClusterStats.put("test_stat", 1L); + adStatsResponse1.setClusterStats(testClusterStats); + + ADStatsResponse adStatsResponse2 = new ADStatsResponse(); + List responses = Collections.emptyList(); + List failures = Collections.emptyList(); + ADStatsNodesResponse adStatsNodesResponse = new ADStatsNodesResponse(ClusterName.DEFAULT, responses, failures); + adStatsResponse2.setADStatsNodesResponse(adStatsNodesResponse); + + adStatsResponse1.merge(adStatsResponse2); + assertEquals(testClusterStats, adStatsResponse1.getClusterStats()); + assertEquals(adStatsNodesResponse, adStatsResponse1.getADStatsNodesResponse()); + + adStatsResponse2.merge(adStatsResponse1); + assertEquals(testClusterStats, adStatsResponse2.getClusterStats()); + assertEquals(adStatsNodesResponse, adStatsResponse2.getADStatsNodesResponse()); + + // Confirm merging with null does nothing + adStatsResponse1.merge(null); + assertEquals(testClusterStats, adStatsResponse1.getClusterStats()); + assertEquals(adStatsNodesResponse, adStatsResponse1.getADStatsNodesResponse()); + + // Confirm merging with self does nothing + adStatsResponse1.merge(adStatsResponse1); + assertEquals(testClusterStats, adStatsResponse1.getClusterStats()); + assertEquals(adStatsNodesResponse, adStatsResponse1.getADStatsNodesResponse()); + } + + @Test + public void testEquals() { + ADStatsResponse adStatsResponse1 = new ADStatsResponse(); + assertEquals(adStatsResponse1, adStatsResponse1); + assertNotEquals(null, adStatsResponse1); + assertNotEquals(1, adStatsResponse1); + ADStatsResponse adStatsResponse2 = new ADStatsResponse(); + assertEquals(adStatsResponse1, adStatsResponse2); + Map testClusterStats = new HashMap<>(); + testClusterStats.put("test_stat", 1L); + adStatsResponse1.setClusterStats(testClusterStats); + assertNotEquals(adStatsResponse1, adStatsResponse2); + } + + @Test + public void testHashCode() { + ADStatsResponse adStatsResponse1 = new 
ADStatsResponse(); + ADStatsResponse adStatsResponse2 = new ADStatsResponse(); + assertEquals(adStatsResponse1.hashCode(), adStatsResponse2.hashCode()); + Map testClusterStats = new HashMap<>(); + testClusterStats.put("test_stat", 1L); + adStatsResponse1.setClusterStats(testClusterStats); + assertNotEquals(adStatsResponse1.hashCode(), adStatsResponse2.hashCode()); + } + + @Test + public void testToXContent() throws IOException { + ADStatsResponse adStatsResponse = new ADStatsResponse(); + Map testClusterStats = new HashMap<>(); + testClusterStats.put("test_stat", 1); + adStatsResponse.setClusterStats(testClusterStats); + List responses = Collections.emptyList(); + List failures = Collections.emptyList(); + ADStatsNodesResponse adStatsNodesResponse = new ADStatsNodesResponse(ClusterName.DEFAULT, responses, failures); + adStatsResponse.setADStatsNodesResponse(adStatsNodesResponse); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + adStatsResponse.toXContent(builder); + XContentParser parser = createParser(builder); + assertEquals(1, parser.map().get("test_stat")); + } +} diff --git a/src/test/java/org/opensearch/ad/stats/ADStatsTests.java-e b/src/test/java/org/opensearch/ad/stats/ADStatsTests.java-e new file mode 100644 index 000000000..0d8150683 --- /dev/null +++ b/src/test/java/org/opensearch/ad/stats/ADStatsTests.java-e @@ -0,0 +1,186 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.stats; + +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_MODEL_SIZE_PER_NODE; + +import java.time.Clock; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.HybridThresholdingModel; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.stats.suppliers.CounterSupplier; +import org.opensearch.ad.stats.suppliers.IndexStatusSupplier; +import org.opensearch.ad.stats.suppliers.ModelsOnNodeSupplier; +import org.opensearch.ad.util.IndexUtils; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.stats.StatNames; + +import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; + +import com.amazon.randomcutforest.RandomCutForest; + +public class ADStatsTests extends OpenSearchTestCase { + + private Map> statsMap; + private ADStats adStats; + private RandomCutForest rcf; + private HybridThresholdingModel thresholdingModel; + private String clusterStatName1, clusterStatName2; + private String nodeStatName1, nodeStatName2; + + @Mock + private Clock clock; + 
+ @Mock + private ModelManager modelManager; + + @Mock + private CacheProvider cacheProvider; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + + // sampleSize * numberOfTrees has to be larger than 1. Otherwise, RCF reports errors. + rcf = RandomCutForest.builder().dimensions(1).sampleSize(2).numberOfTrees(1).build(); + thresholdingModel = new HybridThresholdingModel(1e-8, 1e-5, 200, 10_000, 2, 5_000_000); + + List> modelsInformation = new ArrayList<>( + Arrays + .asList( + new ModelState<>(rcf, "rcf-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), + new ModelState<>(thresholdingModel, "thr-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), + new ModelState<>(rcf, "rcf-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f), + new ModelState<>(thresholdingModel, "thr-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f) + ) + ); + + when(modelManager.getAllModels()).thenReturn(modelsInformation); + + ModelState entityModel1 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState entityModel2 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + + List> entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); + EntityCache cache = mock(EntityCache.class); + when(cacheProvider.get()).thenReturn(cache); + when(cache.getAllModels()).thenReturn(entityModelsInformation); + + IndexUtils indexUtils = mock(IndexUtils.class); + + when(indexUtils.getIndexHealthStatus(anyString())).thenReturn("yellow"); + when(indexUtils.getNumberOfDocumentsInIndex(anyString())).thenReturn(100L); + + clusterStatName1 = "clusterStat1"; + clusterStatName2 = "clusterStat2"; + + nodeStatName1 = "nodeStat1"; + nodeStatName2 = "nodeStat2"; + + Settings settings = Settings.builder().put(MAX_MODEL_SIZE_PER_NODE.getKey(), 10).build(); + ClusterService clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(MAX_MODEL_SIZE_PER_NODE))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + statsMap = new HashMap>() { + { + put(nodeStatName1, new ADStat<>(false, new CounterSupplier())); + put(nodeStatName2, new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService))); + put(clusterStatName1, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index1"))); + put(clusterStatName2, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index2"))); + } + }; + + adStats = new ADStats(statsMap); + } + + @Test + public void testStatNamesGetNames() { + assertEquals("getNames of StatNames returns the incorrect number of stats", StatNames.getNames().size(), StatNames.values().length); + } + + @Test + public void testGetStats() { + Map> stats = adStats.getStats(); + + assertEquals("getStats returns the incorrect number of stats", stats.size(), statsMap.size()); + + for (Map.Entry> stat : stats.entrySet()) { + assertTrue( + "getStats returns incorrect stats", + adStats.getStats().containsKey(stat.getKey()) && adStats.getStats().get(stat.getKey()) == stat.getValue() + ); + } + } + + @Test + public void testGetStat() { + ADStat stat = adStats.getStat(clusterStatName1); + + assertTrue( + "getStat returns incorrect stat", + adStats.getStats().containsKey(clusterStatName1) && adStats.getStats().get(clusterStatName1) == stat + 
); + } + + @Test + public void testGetNodeStats() { + Map> stats = adStats.getStats(); + Set> nodeStats = new HashSet<>(adStats.getNodeStats().values()); + + for (ADStat stat : stats.values()) { + assertTrue( + "getNodeStats returns incorrect stat", + (stat.isClusterLevel() && !nodeStats.contains(stat)) || (!stat.isClusterLevel() && nodeStats.contains(stat)) + ); + } + } + + @Test + public void testGetClusterStats() { + Map> stats = adStats.getStats(); + Set> clusterStats = new HashSet<>(adStats.getClusterStats().values()); + + for (ADStat stat : stats.values()) { + assertTrue( + "getClusterStats returns incorrect stat", + (stat.isClusterLevel() && clusterStats.contains(stat)) || (!stat.isClusterLevel() && !clusterStats.contains(stat)) + ); + } + } + +} diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/CounterSupplierTests.java-e b/src/test/java/org/opensearch/ad/stats/suppliers/CounterSupplierTests.java-e new file mode 100644 index 000000000..333d50ffe --- /dev/null +++ b/src/test/java/org/opensearch/ad/stats/suppliers/CounterSupplierTests.java-e @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.stats.suppliers; + +import org.junit.Test; +import org.opensearch.test.OpenSearchTestCase; + +public class CounterSupplierTests extends OpenSearchTestCase { + @Test + public void testGetAndIncrement() { + CounterSupplier counterSupplier = new CounterSupplier(); + assertEquals("get returns incorrect value", (Long) 0L, counterSupplier.get()); + counterSupplier.increment(); + assertEquals("get returns incorrect value", (Long) 1L, counterSupplier.get()); + } +} diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/IndexSupplierTests.java-e b/src/test/java/org/opensearch/ad/stats/suppliers/IndexSupplierTests.java-e new file mode 100644 index 000000000..cfdf71188 --- /dev/null +++ b/src/test/java/org/opensearch/ad/stats/suppliers/IndexSupplierTests.java-e @@ -0,0 +1,49 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */
+
+package org.opensearch.ad.stats.suppliers;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.opensearch.ad.util.IndexUtils;
+import org.opensearch.test.OpenSearchTestCase;
+
+public class IndexSupplierTests extends OpenSearchTestCase {
+    private IndexUtils indexUtils;
+    private String indexStatus;
+    private String indexName;
+
+    @Before
+    public void setup() {
+        indexUtils = mock(IndexUtils.class);
+        indexStatus = "yellow";
+        indexName = "test-index";
+        when(indexUtils.getIndexHealthStatus(indexName)).thenReturn(indexStatus);
+    }
+
+    @Test
+    public void testGet() {
+        IndexStatusSupplier indexStatusSupplier1 = new IndexStatusSupplier(indexUtils, indexName);
+        assertEquals("Get method for IndexSupplier does not work", indexStatus, indexStatusSupplier1.get());
+
+        String invalidIndex = "invalid";
+        when(indexUtils.getIndexHealthStatus(invalidIndex)).thenThrow(IllegalArgumentException.class);
+        IndexStatusSupplier indexStatusSupplier2 = new IndexStatusSupplier(indexUtils, invalidIndex);
+        assertEquals(
+            "Get method does not return correct response on exception",
+            IndexStatusSupplier.UNABLE_TO_RETRIEVE_HEALTH_MESSAGE,
+            indexStatusSupplier2.get()
+        );
+    }
+}
diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java-e b/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java-e
new file mode 100644
index 000000000..c0173593c
--- /dev/null
+++ b/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java-e
@@ -0,0 +1,125 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */ + +package org.opensearch.ad.stats.suppliers; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_MODEL_SIZE_PER_NODE; +import static org.opensearch.ad.stats.suppliers.ModelsOnNodeSupplier.MODEL_STATE_STAT_KEYS; + +import java.time.Clock; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.HybridThresholdingModel; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.OpenSearchTestCase; + +import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; + +import com.amazon.randomcutforest.RandomCutForest; + +public class ModelsOnNodeSupplierTests extends OpenSearchTestCase { + private RandomCutForest rcf; + private HybridThresholdingModel thresholdingModel; + private List> expectedResults; + private Clock clock; + private List> entityModelsInformation; + + @Mock + private ModelManager modelManager; + + @Mock + private CacheProvider cacheProvider; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + + clock = Clock.systemUTC(); + rcf = RandomCutForest.builder().dimensions(1).sampleSize(2).numberOfTrees(1).build(); + thresholdingModel = new HybridThresholdingModel(1e-8, 1e-5, 200, 10_000, 2, 5_000_000); + + expectedResults = new ArrayList<>( + Arrays + .asList( + new ModelState<>(rcf, "rcf-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), + new ModelState<>(thresholdingModel, "thr-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), + new ModelState<>(rcf, "rcf-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f), + new ModelState<>(thresholdingModel, "thr-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f) + ) + ); + + when(modelManager.getAllModels()).thenReturn(expectedResults); + + ModelState entityModel1 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ModelState entityModel2 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + + entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); + EntityCache cache = mock(EntityCache.class); + when(cacheProvider.get()).thenReturn(cache); + when(cache.getAllModels()).thenReturn(entityModelsInformation); + } + + @Test + public void testGet() { + Settings settings = Settings.builder().put(MAX_MODEL_SIZE_PER_NODE.getKey(), 10).build(); + ClusterService clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(MAX_MODEL_SIZE_PER_NODE))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + ModelsOnNodeSupplier modelsOnNodeSupplier = new 
ModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService); + List> results = modelsOnNodeSupplier.get(); + assertEquals( + "get fails to return correct result", + Stream + .concat(expectedResults.stream(), entityModelsInformation.stream()) + .map( + modelState -> modelState + .getModelStateAsMap() + .entrySet() + .stream() + .filter(entry -> MODEL_STATE_STAT_KEYS.contains(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) + ) + .collect(Collectors.toList()), + results + ); + } + + @Test + public void testGetModelCount() { + ModelsOnNodeCountSupplier modelsOnNodeSupplier = new ModelsOnNodeCountSupplier(modelManager, cacheProvider); + assertEquals(6L, modelsOnNodeSupplier.get().longValue()); + } +} diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/SettableSupplierTests.java-e b/src/test/java/org/opensearch/ad/stats/suppliers/SettableSupplierTests.java-e new file mode 100644 index 000000000..1cf1c9306 --- /dev/null +++ b/src/test/java/org/opensearch/ad/stats/suppliers/SettableSupplierTests.java-e @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.stats.suppliers; + +import org.junit.Test; +import org.opensearch.test.OpenSearchTestCase; + +public class SettableSupplierTests extends OpenSearchTestCase { + @Test + public void testSetGet() { + Long setCount = 15L; + SettableSupplier settableSupplier = new SettableSupplier(); + settableSupplier.set(setCount); + assertEquals("Get/Set fails", setCount, settableSupplier.get()); + } +} diff --git a/src/test/java/org/opensearch/ad/task/ADHCBatchTaskRunStateTests.java-e b/src/test/java/org/opensearch/ad/task/ADHCBatchTaskRunStateTests.java-e new file mode 100644 index 000000000..8f5a33c86 --- /dev/null +++ b/src/test/java/org/opensearch/ad/task/ADHCBatchTaskRunStateTests.java-e @@ -0,0 +1,56 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.task; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; + +import org.opensearch.ad.ADUnitTestCase; + +public class ADHCBatchTaskRunStateTests extends ADUnitTestCase { + + private ADHCBatchTaskRunState taskRunState; + + @Override + public void setUp() throws Exception { + super.setUp(); + taskRunState = new ADHCBatchTaskRunState(); + } + + public void testExpiredForCancel() { + taskRunState.setHistoricalAnalysisCancelled(true); + // not expired if cancelled time is null + assertFalse(taskRunState.expired()); + + // expired if cancelled time is 10 minute ago, default time out is 60s + taskRunState.setCancelledTimeInMillis(Instant.now().minus(10, ChronoUnit.MINUTES).toEpochMilli()); + assertTrue(taskRunState.expired()); + + // not expired if cancelled time is 10 seconds ago, default time out is 60s + taskRunState.setCancelledTimeInMillis(Instant.now().minus(10, ChronoUnit.SECONDS).toEpochMilli()); + assertFalse(taskRunState.expired()); + } + + public void testExpiredForNotCancelled() { + taskRunState.setHistoricalAnalysisCancelled(false); + // not expired if last task run time is null + assertFalse(taskRunState.expired()); + + // expired if last task run time is 10 minute ago, default time out is 60s + taskRunState.setLastTaskRunTimeInMillis(Instant.now().minus(10, ChronoUnit.MINUTES).toEpochMilli()); + assertTrue(taskRunState.expired()); + + // not expired if task run time is 10 seconds ago, default time out is 60s + taskRunState.setLastTaskRunTimeInMillis(Instant.now().minus(10, ChronoUnit.SECONDS).toEpochMilli()); + assertFalse(taskRunState.expired()); + } +} diff --git a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java-e b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java-e new file mode 100644 index 000000000..ba9698d6a --- /dev/null +++ b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java-e @@ -0,0 +1,702 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.task; + +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.MemoryTracker.Origin.HISTORICAL_SINGLE_ENTITY_DETECTOR; +import static org.opensearch.ad.constant.ADCommonMessages.DETECTOR_IS_RUNNING; +import static org.opensearch.ad.task.ADTaskCacheManager.TASK_RETRY_LIMIT; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.ad.MemoryTracker; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskState; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.DuplicateTaskException; +import org.opensearch.timeseries.common.exception.LimitExceededException; + +import com.google.common.collect.ImmutableList; + +public class ADTaskCacheManagerTests extends OpenSearchTestCase { + private MemoryTracker memoryTracker; + private ADTaskCacheManager adTaskCacheManager; + private ClusterService clusterService; + private Settings settings; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + settings = Settings + .builder() + .put(AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE.getKey(), 2) + .put(AnomalyDetectorSettings.MAX_CACHED_DELETED_TASKS.getKey(), 100) + .build(); + + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays.asList(AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE, AnomalyDetectorSettings.MAX_CACHED_DELETED_TASKS) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + memoryTracker = mock(MemoryTracker.class); + adTaskCacheManager = new ADTaskCacheManager(settings, clusterService, memoryTracker); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + adTaskCacheManager.clear(); + } + + public void testPutTask() throws IOException { + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); + ADTask adTask = TestHelpers.randomAdTask(); + adTaskCacheManager.add(adTask); + assertEquals(1, adTaskCacheManager.size()); + assertTrue(adTaskCacheManager.contains(adTask.getTaskId())); + assertTrue(adTaskCacheManager.containsTaskOfDetector(adTask.getId())); + assertNotNull(adTaskCacheManager.getTRcfModel(adTask.getTaskId())); + assertNotNull(adTaskCacheManager.getShingle(adTask.getTaskId())); + assertFalse(adTaskCacheManager.isThresholdModelTrained(adTask.getTaskId())); + adTaskCacheManager.remove(adTask.getTaskId(), randomAlphaOfLength(5), randomAlphaOfLength(5)); + assertEquals(0, adTaskCacheManager.size()); + } + + public void 
testPutDuplicateTask() throws IOException { + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); + ADTask adTask1 = TestHelpers.randomAdTask(); + adTaskCacheManager.add(adTask1); + assertEquals(1, adTaskCacheManager.size()); + DuplicateTaskException e1 = expectThrows(DuplicateTaskException.class, () -> adTaskCacheManager.add(adTask1)); + assertEquals(DETECTOR_IS_RUNNING, e1.getMessage()); + + ADTask adTask2 = TestHelpers + .randomAdTask( + randomAlphaOfLength(5), + ADTaskState.INIT, + adTask1.getExecutionEndTime(), + adTask1.getStoppedBy(), + adTask1.getId(), + adTask1.getDetector(), + ADTaskType.HISTORICAL_SINGLE_ENTITY + ); + DuplicateTaskException e2 = expectThrows(DuplicateTaskException.class, () -> adTaskCacheManager.add(adTask2)); + assertEquals(DETECTOR_IS_RUNNING, e2.getMessage()); + } + + public void testPutMultipleEntityTasks() throws IOException { + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); + AnomalyDetector detector = TestHelpers + .randomAnomalyDetector( + ImmutableList.of(TestHelpers.randomFeature(true)), + null, + Instant.now(), + true, + ImmutableList.of(randomAlphaOfLength(5)) + ); + ADTask adTask1 = TestHelpers + .randomAdTask( + randomAlphaOfLength(5), + ADTaskState.CREATED, + Instant.now(), + null, + detector.getId(), + detector, + ADTaskType.HISTORICAL_HC_ENTITY + ); + ADTask adTask2 = TestHelpers + .randomAdTask( + randomAlphaOfLength(5), + ADTaskState.CREATED, + Instant.now(), + null, + detector.getId(), + detector, + ADTaskType.HISTORICAL_HC_ENTITY + ); + adTaskCacheManager.add(adTask1); + adTaskCacheManager.add(adTask2); + List tasks = adTaskCacheManager.getTasksOfDetector(detector.getId()); + assertEquals(2, tasks.size()); + } + + public void testAddDetector() throws IOException { + String detectorId = randomAlphaOfLength(10); + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_DETECTOR); + adTaskCacheManager.add(detectorId, adTask); + DuplicateTaskException e1 = expectThrows(DuplicateTaskException.class, () -> adTaskCacheManager.add(detectorId, adTask)); + assertEquals(DETECTOR_IS_RUNNING, e1.getMessage()); + } + + public void testPutTaskWithMemoryExceedLimit() { + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(false); + LimitExceededException exception = expectThrows( + LimitExceededException.class, + () -> adTaskCacheManager.add(TestHelpers.randomAdTask()) + ); + assertEquals("Not enough memory to run detector", exception.getMessage()); + } + + public void testThresholdModelTrained() throws IOException { + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); + ADTask adTask = TestHelpers.randomAdTask(); + adTaskCacheManager.add(adTask); + assertEquals(1, adTaskCacheManager.size()); + adTaskCacheManager.setThresholdModelTrained(adTask.getTaskId(), false); + verify(memoryTracker, never()).releaseMemory(anyLong(), anyBoolean(), eq(HISTORICAL_SINGLE_ENTITY_DETECTOR)); + adTaskCacheManager.setThresholdModelTrained(adTask.getTaskId(), true); + verify(memoryTracker, times(0)).releaseMemory(anyLong(), eq(true), eq(HISTORICAL_SINGLE_ENTITY_DETECTOR)); + } + + public void testTaskNotExist() { + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> adTaskCacheManager.getTRcfModel(randomAlphaOfLength(5)) + ); + assertEquals("AD task not in cache", e.getMessage()); + } + + public void testRemoveTaskWhichNotExist() { + adTaskCacheManager.remove(randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5)); + verify(memoryTracker, 
never()).releaseMemory(anyLong(), anyBoolean(), eq(HISTORICAL_SINGLE_ENTITY_DETECTOR)); + } + + public void testExceedRunningTaskLimit() throws IOException { + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); + adTaskCacheManager.add(TestHelpers.randomAdTask()); + adTaskCacheManager.add(TestHelpers.randomAdTask()); + assertEquals(2, adTaskCacheManager.size()); + LimitExceededException e = expectThrows(LimitExceededException.class, () -> adTaskCacheManager.add(TestHelpers.randomAdTask())); + assertEquals("Exceed max historical analysis limit per node: 2", e.getMessage()); + } + + public void testCancelByDetectorIdWhichNotExist() { + String detectorId = randomAlphaOfLength(10); + String detectorTaskId = randomAlphaOfLength(10); + String reason = randomAlphaOfLength(10); + String userName = randomAlphaOfLength(5); + ADTaskCancellationState state = adTaskCacheManager.cancelByDetectorId(detectorId, detectorTaskId, reason, userName); + assertEquals("Wrong task cancellation state", ADTaskCancellationState.NOT_FOUND, state); + } + + public void testCancelByDetectorId() throws IOException { + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); + ADTask adTask = TestHelpers.randomAdTask(); + adTaskCacheManager.add(adTask); + String detectorId = adTask.getId(); + String detectorTaskId = adTask.getId(); + String reason = randomAlphaOfLength(10); + String userName = randomAlphaOfLength(5); + ADTaskCancellationState state = adTaskCacheManager.cancelByDetectorId(detectorId, detectorTaskId, reason, userName); + assertEquals("Wrong task cancellation state", ADTaskCancellationState.CANCELLED, state); + assertTrue(adTaskCacheManager.isCancelled(adTask.getTaskId())); + + state = adTaskCacheManager.cancelByDetectorId(detectorId, detectorTaskId, reason, userName); + assertEquals("Wrong task cancellation state", ADTaskCancellationState.ALREADY_CANCELLED, state); + } + + public void testTopEntityInited() throws IOException { + String detectorId = randomAlphaOfLength(10); + assertFalse(adTaskCacheManager.topEntityInited(detectorId)); + adTaskCacheManager.add(detectorId, TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_DETECTOR)); + assertFalse(adTaskCacheManager.topEntityInited(detectorId)); + adTaskCacheManager.setTopEntityInited(detectorId); + assertTrue(adTaskCacheManager.topEntityInited(detectorId)); + } + + public void testEntityCache() throws IOException { + String detectorId = randomAlphaOfLength(10); + assertEquals(0, adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getRunningEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getTopEntityCount(detectorId).intValue()); + adTaskCacheManager.add(detectorId, TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_DETECTOR)); + String entity1 = randomAlphaOfLength(5); + String entity2 = randomAlphaOfLength(5); + String entity3 = randomAlphaOfLength(5); + List entities = ImmutableList.of(entity1, entity2, entity3); + adTaskCacheManager.addPendingEntities(detectorId, entities); + adTaskCacheManager.setTopEntityCount(detectorId, entities.size()); + adTaskCacheManager.pollEntity(detectorId); + assertEquals(3, adTaskCacheManager.getTopEntityCount(detectorId).intValue()); + assertEquals(2, adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getRunningEntityCount(detectorId)); + adTaskCacheManager.moveToRunningEntity(detectorId, entity1); + assertEquals(3, adTaskCacheManager.getTopEntityCount(detectorId).intValue()); + assertEquals(2, 
adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(1, adTaskCacheManager.getRunningEntityCount(detectorId)); + assertArrayEquals(new String[] { entity1 }, adTaskCacheManager.getRunningEntities(detectorId).toArray(new String[0])); + + assertFalse(adTaskCacheManager.removeRunningEntity(randomAlphaOfLength(10), entity1)); + assertFalse(adTaskCacheManager.removeRunningEntity(detectorId, randomAlphaOfLength(5))); + + assertTrue(adTaskCacheManager.removeRunningEntity(detectorId, entity1)); + assertEquals(2, adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getRunningEntityCount(detectorId)); + + adTaskCacheManager.removeEntity(detectorId, entity2); + assertEquals(1, adTaskCacheManager.getPendingEntityCount(detectorId)); + + adTaskCacheManager.clearPendingEntities(detectorId); + assertEquals(0, adTaskCacheManager.getPendingEntityCount(detectorId)); + + assertNull(adTaskCacheManager.pollEntity(detectorId)); + + assertNull(adTaskCacheManager.getRunningEntities(randomAlphaOfLength(10))); + } + + public void testPollEntityWithNotExistingHCDetector() { + assertNull(adTaskCacheManager.pollEntity(randomAlphaOfLength(5))); + } + + public void testPushBackEntity() throws IOException { + String detectorId = randomAlphaOfLength(10); + adTaskCacheManager.add(detectorId, TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_DETECTOR)); + String entity1 = randomAlphaOfLength(5); + String taskId = randomAlphaOfLength(5); + adTaskCacheManager.pushBackEntity(taskId, detectorId, entity1); + + assertFalse(adTaskCacheManager.exceedRetryLimit(detectorId, taskId)); + for (int i = 0; i < TASK_RETRY_LIMIT; i++) { + adTaskCacheManager.pushBackEntity(taskId, detectorId, entity1); + } + assertTrue(adTaskCacheManager.exceedRetryLimit(detectorId, taskId)); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> adTaskCacheManager.exceedRetryLimit(randomAlphaOfLength(10), taskId) + ); + assertEquals("Can't find HC detector in cache", exception.getMessage()); + } + + public void testRealtimeTaskCache() { + String detectorId1 = randomAlphaOfLength(10); + String newState = ADTaskState.INIT.name(); + Float newInitProgress = 0.0f; + String newError = randomAlphaOfLength(5); + assertTrue(adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId1, newState, newInitProgress, newError)); + // Init realtime task cache. 
+ adTaskCacheManager.initRealtimeTaskCache(detectorId1, 60_000); + + adTaskCacheManager.updateRealtimeTaskCache(detectorId1, newState, newInitProgress, newError); + assertFalse(adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId1, newState, newInitProgress, newError)); + assertArrayEquals(new String[] { detectorId1 }, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()); + + String detectorId2 = randomAlphaOfLength(10); + adTaskCacheManager.updateRealtimeTaskCache(detectorId2, newState, newInitProgress, newError); + assertEquals(1, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache().length); + adTaskCacheManager.initRealtimeTaskCache(detectorId2, 60_000); + adTaskCacheManager.updateRealtimeTaskCache(detectorId2, newState, newInitProgress, newError); + assertEquals(2, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache().length); + + newState = ADTaskState.RUNNING.name(); + newInitProgress = 1.0f; + newError = "test error"; + assertTrue(adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId1, newState, newInitProgress, newError)); + adTaskCacheManager.updateRealtimeTaskCache(detectorId1, newState, newInitProgress, newError); + assertEquals(newInitProgress, adTaskCacheManager.getRealtimeTaskCache(detectorId1).getInitProgress()); + assertEquals(newState, adTaskCacheManager.getRealtimeTaskCache(detectorId1).getState()); + assertEquals(newError, adTaskCacheManager.getRealtimeTaskCache(detectorId1).getError()); + + adTaskCacheManager.removeRealtimeTaskCache(detectorId1); + assertArrayEquals(new String[] { detectorId2 }, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()); + + adTaskCacheManager.clearRealtimeTaskCache(); + assertEquals(0, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache().length); + + } + + public void testUpdateRealtimeTaskCache() { + String detectorId = randomAlphaOfLength(5); + adTaskCacheManager.initRealtimeTaskCache(detectorId, 60_000); + adTaskCacheManager.updateRealtimeTaskCache(detectorId, null, null, null); + ADRealtimeTaskCache realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); + assertNull(realtimeTaskCache.getState()); + assertNull(realtimeTaskCache.getError()); + assertNull(realtimeTaskCache.getInitProgress()); + + String state = ADTaskState.RUNNING.name(); + Float initProgress = 0.1f; + String error = randomAlphaOfLength(5); + adTaskCacheManager.updateRealtimeTaskCache(detectorId, state, initProgress, error); + realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); + assertEquals(state, realtimeTaskCache.getState()); + assertEquals(error, realtimeTaskCache.getError()); + assertEquals(initProgress, realtimeTaskCache.getInitProgress()); + + state = ADTaskState.STOPPED.name(); + adTaskCacheManager.updateRealtimeTaskCache(detectorId, state, initProgress, error); + realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); + assertNull(realtimeTaskCache); + } + + public void testGetAndDecreaseEntityTaskLanes() throws IOException { + String detectorId = randomAlphaOfLength(10); + adTaskCacheManager.add(detectorId, TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_DETECTOR)); + adTaskCacheManager.setAllowedRunningEntities(detectorId, 1); + assertEquals(1, adTaskCacheManager.getAndDecreaseEntityTaskLanes(detectorId)); + assertEquals(0, adTaskCacheManager.getAndDecreaseEntityTaskLanes(detectorId)); + } + + public void testDeletedTask() { + String taskId = randomAlphaOfLength(10); + adTaskCacheManager.addDeletedDetectorTask(taskId); + assertTrue(adTaskCacheManager.hasDeletedDetectorTask()); + 
assertEquals(taskId, adTaskCacheManager.pollDeletedDetectorTask()); + assertFalse(adTaskCacheManager.hasDeletedDetectorTask()); + } + + public void testAcquireTaskUpdatingSemaphore() throws IOException, InterruptedException { + String detectorId = randomAlphaOfLength(10); + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_DETECTOR); + adTaskCacheManager.add(detectorId, adTask); + assertTrue(adTaskCacheManager.tryAcquireTaskUpdatingSemaphore(detectorId, 0)); + assertFalse(adTaskCacheManager.tryAcquireTaskUpdatingSemaphore(detectorId, 0)); + } + + public void testGetTasksOfDetectorWithNonExistingDetectorId() throws IOException { + List tasksOfDetector = adTaskCacheManager.getTasksOfDetector(randomAlphaOfLength(10)); + assertEquals(0, tasksOfDetector.size()); + } + + public void testHistoricalTaskCache() throws IOException, InterruptedException { + List result = addHCDetectorCache(); + String detectorId = result.get(0); + String detectorTaskId = result.get(1); + assertTrue(adTaskCacheManager.containsTaskOfDetector(detectorId)); + assertTrue(adTaskCacheManager.isHCTaskCoordinatingNode(detectorId)); + assertTrue(adTaskCacheManager.isHCTaskRunning(detectorId)); + assertEquals(detectorTaskId, adTaskCacheManager.getDetectorTaskId(detectorId)); + Instant lastScaleEntityTaskLaneTime = adTaskCacheManager.getLastScaleEntityTaskLaneTime(detectorId); + assertNotNull(lastScaleEntityTaskLaneTime); + Thread.sleep(500); + adTaskCacheManager.refreshLastScaleEntityTaskLaneTime(detectorId); + assertTrue(lastScaleEntityTaskLaneTime.isBefore(adTaskCacheManager.getLastScaleEntityTaskLaneTime(detectorId))); + + adTaskCacheManager.removeHistoricalTaskCache(detectorId); + assertFalse(adTaskCacheManager.containsTaskOfDetector(detectorId)); + assertFalse(adTaskCacheManager.isHCTaskCoordinatingNode(detectorId)); + assertFalse(adTaskCacheManager.isHCTaskRunning(detectorId)); + assertNull(adTaskCacheManager.getDetectorTaskId(detectorId)); + assertNull(adTaskCacheManager.getLastScaleEntityTaskLaneTime(detectorId)); + } + + private List addHCDetectorCache() throws IOException { + when(memoryTracker.canAllocateReserved(anyLong())).thenReturn(true); + AnomalyDetector detector = TestHelpers + .randomAnomalyDetector( + ImmutableList.of(TestHelpers.randomFeature(true)), + null, + Instant.now(), + true, + ImmutableList.of(randomAlphaOfLength(5)) + ); + String detectorId = detector.getId(); + ADTask adDetectorTask = TestHelpers + .randomAdTask( + randomAlphaOfLength(5), + ADTaskState.CREATED, + Instant.now(), + null, + detectorId, + detector, + ADTaskType.HISTORICAL_HC_DETECTOR + ); + ADTask adEntityTask = TestHelpers + .randomAdTask( + randomAlphaOfLength(5), + ADTaskState.CREATED, + Instant.now(), + null, + detectorId, + detector, + ADTaskType.HISTORICAL_HC_ENTITY + ); + adTaskCacheManager.add(detectorId, adDetectorTask); + adTaskCacheManager.add(adEntityTask); + assertEquals(adEntityTask.getEntity(), adTaskCacheManager.getEntity(adEntityTask.getTaskId())); + String entityValue = randomAlphaOfLength(5); + adTaskCacheManager.addPendingEntities(detectorId, ImmutableList.of(entityValue)); + assertEquals(1, adTaskCacheManager.getUnfinishedEntityCount(detectorId)); + return ImmutableList.of(detectorId, adDetectorTask.getTaskId(), adEntityTask.getTaskId(), entityValue); + } + + public void testCancelHCDetector() throws IOException { + List result = addHCDetectorCache(); + String detectorId = result.get(0); + String entityTaskId = result.get(2); + assertFalse(adTaskCacheManager.isCancelled(entityTaskId)); + 
adTaskCacheManager.cancelByDetectorId(detectorId, "testDetectorTaskId", "testReason", "testUser"); + assertTrue(adTaskCacheManager.isCancelled(entityTaskId)); + } + + public void testTempEntity() throws IOException { + List result = addHCDetectorCache(); + String detectorId = result.get(0); + String entityValue = result.get(3); + assertEquals(1, adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getTempEntityCount(detectorId)); + adTaskCacheManager.pollEntity(detectorId); + assertEquals(0, adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(1, adTaskCacheManager.getTempEntityCount(detectorId)); + adTaskCacheManager.addPendingEntity(detectorId, entityValue); + assertEquals(1, adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getTempEntityCount(detectorId)); + assertNotNull(adTaskCacheManager.pollEntity(detectorId)); + assertNull(adTaskCacheManager.pollEntity(detectorId)); + } + + public void testScaleTaskSlots() throws IOException { + List result = addHCDetectorCache(); + String detectorId = result.get(0); + int taskSlots = randomIntBetween(6, 10); + int taskLaneLimit = randomIntBetween(1, 10); + adTaskCacheManager.setDetectorTaskLaneLimit(detectorId, taskLaneLimit); + adTaskCacheManager.setDetectorTaskSlots(detectorId, taskSlots); + assertEquals(taskSlots, adTaskCacheManager.getDetectorTaskSlots(detectorId)); + int scaleUpDelta = randomIntBetween(1, 5); + adTaskCacheManager.scaleUpDetectorTaskSlots(detectorId, scaleUpDelta); + assertEquals(taskSlots + scaleUpDelta, adTaskCacheManager.getDetectorTaskSlots(detectorId)); + int scaleDownDelta = randomIntBetween(1, 5); + int newTaskSlots = adTaskCacheManager.scaleDownHCDetectorTaskSlots(detectorId, scaleDownDelta); + assertEquals(taskSlots + scaleUpDelta - scaleDownDelta, newTaskSlots); + assertEquals(taskSlots + scaleUpDelta - scaleDownDelta, adTaskCacheManager.getDetectorTaskSlots(detectorId)); + int newTaskSlots2 = adTaskCacheManager.scaleDownHCDetectorTaskSlots(detectorId, taskSlots * 10); + assertEquals(newTaskSlots, adTaskCacheManager.getDetectorTaskSlots(detectorId)); + assertEquals(newTaskSlots, newTaskSlots2); + } + + public void testDetectorTaskSlots() { + assertEquals(0, adTaskCacheManager.getDetectorTaskSlots(randomAlphaOfLength(5))); + + String detectorId = randomAlphaOfLength(5); + adTaskCacheManager.setDetectorTaskLaneLimit(detectorId, randomIntBetween(1, 10)); + assertEquals(0, adTaskCacheManager.getDetectorTaskSlots(randomAlphaOfLength(5))); + int taskSlots = randomIntBetween(1, 10); + adTaskCacheManager.setDetectorTaskSlots(detectorId, taskSlots); + assertEquals(taskSlots, adTaskCacheManager.getDetectorTaskSlots(detectorId)); + } + + public void testTaskLanes() throws IOException { + List result = addHCDetectorCache(); + String detectorId = result.get(0); + int maxTaskLanes = randomIntBetween(1, 10); + adTaskCacheManager.setAllowedRunningEntities(detectorId, maxTaskLanes); + assertEquals(maxTaskLanes, adTaskCacheManager.getAvailableNewEntityTaskLanes(detectorId)); + } + + public void testRefreshRealtimeJobRunTime() throws InterruptedException { + String detectorId = randomAlphaOfLength(5); + adTaskCacheManager.initRealtimeTaskCache(detectorId, 1_000); + ADRealtimeTaskCache realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); + assertFalse(realtimeTaskCache.expired()); + Thread.sleep(3_000); + assertTrue(realtimeTaskCache.expired()); + adTaskCacheManager.refreshRealtimeJobRunTime(detectorId); + 
assertFalse(realtimeTaskCache.expired()); + } + + public void testAddDeletedDetector() { + String detectorId = randomAlphaOfLength(5); + adTaskCacheManager.addDeletedDetector(detectorId); + String polledDetectorId = adTaskCacheManager.pollDeletedDetector(); + assertEquals(detectorId, polledDetectorId); + assertNull(adTaskCacheManager.pollDeletedDetector()); + } + + public void testAddPendingEntitiesWithEmptyList() throws IOException { + String detectorId = randomAlphaOfLength(5); + expectThrows(IllegalArgumentException.class, () -> adTaskCacheManager.addPendingEntities(detectorId, null)); + + adTaskCacheManager.add(detectorId, TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_DETECTOR)); + assertEquals(0, adTaskCacheManager.getPendingEntityCount(detectorId)); + adTaskCacheManager.addPendingEntities(detectorId, ImmutableList.of()); + assertEquals(0, adTaskCacheManager.getPendingEntityCount(detectorId)); + } + + public void testMoveToRunningEntity() throws IOException { + String detectorId = randomAlphaOfLength(5); + String entity = randomAlphaOfLength(5); + adTaskCacheManager.add(detectorId, TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_DETECTOR)); + adTaskCacheManager.addPendingEntities(detectorId, ImmutableList.of(entity)); + assertEquals(1, adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getTempEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getRunningEntityCount(detectorId)); + adTaskCacheManager.moveToRunningEntity(detectorId, null); + assertEquals(1, adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getTempEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getRunningEntityCount(detectorId)); + adTaskCacheManager.moveToRunningEntity(detectorId, entity); + assertEquals(1, adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getTempEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getRunningEntityCount(detectorId)); + adTaskCacheManager.pollEntity(detectorId); + assertEquals(0, adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(1, adTaskCacheManager.getTempEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getRunningEntityCount(detectorId)); + adTaskCacheManager.moveToRunningEntity(detectorId, entity); + assertEquals(0, adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getTempEntityCount(detectorId)); + assertEquals(1, adTaskCacheManager.getRunningEntityCount(detectorId)); + } + + public void testRemoveEntity() throws IOException { + String detectorId = randomAlphaOfLength(5); + String entity = randomAlphaOfLength(5); + adTaskCacheManager.add(detectorId, TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_DETECTOR)); + adTaskCacheManager.addPendingEntities(detectorId, ImmutableList.of(entity)); + assertEquals(1, adTaskCacheManager.getPendingEntityCount(detectorId)); + adTaskCacheManager.removeEntity(detectorId, null); + assertEquals(1, adTaskCacheManager.getPendingEntityCount(detectorId)); + + adTaskCacheManager.pollEntity(detectorId); + assertEquals(0, adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(1, adTaskCacheManager.getTempEntityCount(detectorId)); + adTaskCacheManager.removeEntity(detectorId, entity); + assertEquals(0, adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getTempEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getRunningEntityCount(detectorId)); + + 
adTaskCacheManager.addPendingEntities(detectorId, ImmutableList.of(entity)); + adTaskCacheManager.moveToRunningEntity(detectorId, entity); + assertEquals(1, adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getTempEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getRunningEntityCount(detectorId)); + adTaskCacheManager.pollEntity(detectorId); + adTaskCacheManager.moveToRunningEntity(detectorId, entity); + assertEquals(0, adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getTempEntityCount(detectorId)); + assertEquals(1, adTaskCacheManager.getRunningEntityCount(detectorId)); + + adTaskCacheManager.removeEntity(detectorId, entity); + assertEquals(0, adTaskCacheManager.getPendingEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getTempEntityCount(detectorId)); + assertEquals(0, adTaskCacheManager.getRunningEntityCount(detectorId)); + } + + public void testADHCBatchTaskRunStateCacheWithCancel() { + String detectorId = randomAlphaOfLength(5); + String detectorTaskId = randomAlphaOfLength(5); + assertFalse(adTaskCacheManager.detectorTaskStateExists(detectorId, detectorTaskId)); + assertNull(adTaskCacheManager.getDetectorTaskState(detectorId, detectorTaskId)); + + ADHCBatchTaskRunState state = adTaskCacheManager.getOrCreateHCDetectorTaskStateCache(detectorId, detectorTaskId); + assertTrue(adTaskCacheManager.detectorTaskStateExists(detectorId, detectorTaskId)); + assertEquals(ADTaskState.INIT.name(), state.getDetectorTaskState()); + assertFalse(state.expired()); + + state.setDetectorTaskState(ADTaskState.RUNNING.name()); + assertEquals(ADTaskState.RUNNING.name(), adTaskCacheManager.getDetectorTaskState(detectorId, detectorTaskId)); + + String cancelReason = randomAlphaOfLength(5); + String cancelledBy = randomAlphaOfLength(5); + adTaskCacheManager.cancelByDetectorId(detectorId, detectorTaskId, cancelReason, cancelledBy); + assertEquals(cancelReason, adTaskCacheManager.getCancelReasonForHC(detectorId, detectorTaskId)); + assertEquals(cancelledBy, adTaskCacheManager.getCancelledByForHC(detectorId, detectorTaskId)); + + expectThrows(IllegalArgumentException.class, () -> adTaskCacheManager.cancelByDetectorId(null, null, cancelReason, cancelledBy)); + expectThrows( + IllegalArgumentException.class, + () -> adTaskCacheManager.cancelByDetectorId(detectorId, null, cancelReason, cancelledBy) + ); + expectThrows( + IllegalArgumentException.class, + () -> adTaskCacheManager.cancelByDetectorId(null, detectorTaskId, cancelReason, cancelledBy) + ); + } + + public void testUpdateDetectorTaskState() { + String detectorId = randomAlphaOfLength(5); + String detectorTaskId = randomAlphaOfLength(5); + String newState = ADTaskState.RUNNING.name(); + + adTaskCacheManager.updateDetectorTaskState(detectorId, detectorTaskId, newState); + assertEquals(newState, adTaskCacheManager.getDetectorTaskState(detectorId, detectorTaskId)); + } + + public void testReleaseTaskUpdatingSemaphore() throws IOException, InterruptedException { + String detectorId = randomAlphaOfLength(5); + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_DETECTOR); + assertFalse(adTaskCacheManager.tryAcquireTaskUpdatingSemaphore(detectorId, 0)); + adTaskCacheManager.releaseTaskUpdatingSemaphore(detectorId); + assertFalse(adTaskCacheManager.tryAcquireTaskUpdatingSemaphore(detectorId, 0)); + + adTaskCacheManager.add(detectorId, adTask); + assertTrue(adTaskCacheManager.tryAcquireTaskUpdatingSemaphore(detectorId, 0)); + 
assertFalse(adTaskCacheManager.tryAcquireTaskUpdatingSemaphore(detectorId, 0)); + adTaskCacheManager.releaseTaskUpdatingSemaphore(detectorId); + assertTrue(adTaskCacheManager.tryAcquireTaskUpdatingSemaphore(detectorId, 0)); + } + + public void testCleanExpiredHCBatchTaskRunStates() { + String detectorId = randomAlphaOfLength(5); + String detectorTaskId = randomAlphaOfLength(5); + ADHCBatchTaskRunState state = adTaskCacheManager.getOrCreateHCDetectorTaskStateCache(detectorId, detectorTaskId); + state.setHistoricalAnalysisCancelled(true); + state.setCancelReason(randomAlphaOfLength(5)); + state.setCancelledBy(randomAlphaOfLength(5)); + state.setCancelledTimeInMillis(Instant.now().minus(10, ChronoUnit.MINUTES).toEpochMilli()); + assertTrue(adTaskCacheManager.isHistoricalAnalysisCancelledForHC(detectorId, detectorTaskId)); + + adTaskCacheManager.cleanExpiredHCBatchTaskRunStates(); + assertFalse(adTaskCacheManager.isHistoricalAnalysisCancelledForHC(detectorId, detectorTaskId)); + } + + public void testRemoveHistoricalTaskCacheIfNoRunningEntity() throws IOException { + String detectorId = randomAlphaOfLength(5); + adTaskCacheManager.removeHistoricalTaskCacheIfNoRunningEntity(detectorId); + + // Add pending entity should not impact remove historical task cache + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_DETECTOR); + adTaskCacheManager.add(detectorId, adTask); + adTaskCacheManager.addPendingEntity(detectorId, randomAlphaOfLength(5)); + adTaskCacheManager.removeHistoricalTaskCacheIfNoRunningEntity(detectorId); + + // Add pending entity and move it to running should impact remove historical task cache + adTaskCacheManager.add(detectorId, adTask); + String entity = randomAlphaOfLength(5); + adTaskCacheManager.addPendingEntity(detectorId, entity); + String pollEntity = adTaskCacheManager.pollEntity(detectorId); + assertEquals(entity, pollEntity); + expectThrows(IllegalArgumentException.class, () -> adTaskCacheManager.removeHistoricalTaskCacheIfNoRunningEntity(detectorId)); + } +} diff --git a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java index f8b5dcfc6..f1b67e71e 100644 --- a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java +++ b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java @@ -103,21 +103,21 @@ import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.index.Index; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.index.get.GetResult; import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.DeleteByQueryAction; -import org.opensearch.index.shard.ShardId; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import 
org.opensearch.search.aggregations.InternalAggregations; diff --git a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java-e b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java-e new file mode 100644 index 000000000..f93fdcf52 --- /dev/null +++ b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java-e @@ -0,0 +1,1627 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.task; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyFloat; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.constant.ADCommonName.ANOMALY_RESULT_INDEX_ALIAS; +import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.DELETE_AD_RESULT_WHEN_DELETE_DETECTOR; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; +import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; +import static org.opensearch.timeseries.TestHelpers.randomAdTask; +import static org.opensearch.timeseries.TestHelpers.randomAnomalyDetector; +import static org.opensearch.timeseries.TestHelpers.randomDetectionDateRange; +import static org.opensearch.timeseries.TestHelpers.randomDetector; +import static org.opensearch.timeseries.TestHelpers.randomFeature; +import static org.opensearch.timeseries.TestHelpers.randomIntervalSchedule; +import static org.opensearch.timeseries.TestHelpers.randomIntervalTimeConfiguration; +import static org.opensearch.timeseries.TestHelpers.randomUser; +import static org.opensearch.timeseries.constant.CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED; +import static org.opensearch.timeseries.model.Entity.createSingleAttributeEntity; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.function.Consumer; + +import org.apache.lucene.search.TotalHits; 
+import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.ad.ADUnitTestCase; +import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.mock.model.MockSimpleLog; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskAction; +import org.opensearch.ad.model.ADTaskProfile; +import org.opensearch.ad.model.ADTaskState; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler; +import org.opensearch.ad.stats.InternalStatNames; +import org.opensearch.ad.transport.ADStatsNodeResponse; +import org.opensearch.ad.transport.ADStatsNodesResponse; +import org.opensearch.ad.transport.ADTaskProfileNodeResponse; +import org.opensearch.ad.transport.ADTaskProfileResponse; +import org.opensearch.ad.transport.AnomalyDetectorJobResponse; +import org.opensearch.ad.transport.ForwardADTaskRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.index.Index; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.index.get.GetResult; +import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.DuplicateTaskException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.transport.TransportResponseHandler; +import 
org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class ADTaskManagerTests extends ADUnitTestCase { + + private Settings settings; + private Client client; + private ClusterService clusterService; + private ClusterSettings clusterSettings; + private DiscoveryNodeFilterer nodeFilter; + private ADIndexManagement detectionIndices; + private ADTaskCacheManager adTaskCacheManager; + private HashRing hashRing; + private ThreadContext.StoredContext context; + private ThreadContext threadContext; + private TransportService transportService; + private ADTaskManager adTaskManager; + private ThreadPool threadPool; + private IndexAnomalyDetectorJobActionHandler indexAnomalyDetectorJobActionHandler; + + private DateRange detectionDateRange; + private ActionListener listener; + + private DiscoveryNode node1; + private DiscoveryNode node2; + + private int maxRunningEntities; + private int maxBatchTaskPerNode; + + private String historicalTaskId = "test_historical_task_id"; + private String realtimeTaskId = "test_realtime_task_id"; + private String runningHistoricalHCTaskContent = "{\"_index\":\".opendistro-anomaly-detection-state\",\"_type\":\"_doc\",\"_id\":\"" + + historicalTaskId + + "\",\"_score\":1,\"_source\":{\"last_update_time\":1630999442827,\"state\":\"RUNNING\",\"detector_id\":" + + "\"tQQiv3sBr1GKRuDiJ5uI\",\"task_progress\":1,\"init_progress\":1,\"execution_start_time\":1630999393798," + + "\"is_latest\":true,\"task_type\":\"HISTORICAL_HC_DETECTOR\",\"coordinating_node\":\"u8aYDPmaS4Ccd08Ed0GNQw\"," + + "\"detector\":{\"name\":\"test-hc1\",\"description\":\"test\",\"time_field\":\"timestamp\"," + + "\"indices\":[\"nab_ec2_cpu_utilization_24ae8d\"],\"filter_query\":{\"match_all\":{\"boost\":1}}," + + "\"detection_interval\":{\"period\":{\"interval\":1,\"unit\":\"Minutes\"}},\"window_delay\":{\"period\"" + + ":{\"interval\":1,\"unit\":\"Minutes\"}},\"shingle_size\":8,\"schema_version\":0,\"feature_attributes\"" + + ":[{\"feature_id\":\"tAQiv3sBr1GKRuDiJ5ty\",\"feature_name\":\"F1\",\"feature_enabled\":true," + + "\"aggregation_query\":{\"f_1\":{\"sum\":{\"field\":\"value\"}}}}],\"ui_metadata\":{\"features\":" + + "{\"F1\":{\"featureType\":\"simple_aggs\",\"aggregationBy\":\"sum\",\"aggregationOf\":\"value\"}}," + + "\"filters\":[]},\"last_update_time\":1630999291783,\"category_field\":[\"type\"],\"detector_type\":" + + "\"MULTI_ENTITY\"},\"detection_date_range\":{\"start_time\":1628407291580,\"end_time\":1630999291580}," + + "\"entity\":[{\"name\":\"type\",\"value\":\"error10\"}],\"parent_task_id\":\"a1civ3sBwF58XZxvKrko\"," + + "\"worker_node\":\"DL5uOJV3TjOOAyh5hJXrCA\",\"current_piece\":1630999260000,\"execution_end_time\":1630999442814}}"; + + private String taskContent = "{\"_index\":\".opendistro-anomaly-detection-state\",\"_type\":\"_doc\",\"_id\":" + + "\"-1ojv3sBwF58XZxvtksG\",\"_score\":1,\"_source\":{\"last_update_time\":1630999442827,\"state\":\"FINISHED\"" + + ",\"detector_id\":\"tQQiv3sBr1GKRuDiJ5uI\",\"task_progress\":1,\"init_progress\":1,\"execution_start_time\"" + + ":1630999393798,\"is_latest\":true,\"task_type\":\"HISTORICAL_HC_ENTITY\",\"coordinating_node\":\"" + + "u8aYDPmaS4Ccd08Ed0GNQw\",\"detector\":{\"name\":\"test-hc1\",\"description\":\"test\",\"time_field\":\"" + + 
"timestamp\",\"indices\":[\"nab_ec2_cpu_utilization_24ae8d\"],\"filter_query\":{\"match_all\":{\"boost\":1}}" + + ",\"detection_interval\":{\"period\":{\"interval\":1,\"unit\":\"Minutes\"}},\"window_delay\":{\"period\":" + + "{\"interval\":1,\"unit\":\"Minutes\"}},\"shingle_size\":8,\"schema_version\":0,\"feature_attributes\":" + + "[{\"feature_id\":\"tAQiv3sBr1GKRuDiJ5ty\",\"feature_name\":\"F1\",\"feature_enabled\":true,\"aggregation_query" + + "\":{\"f_1\":{\"sum\":{\"field\":\"value\"}}}}],\"ui_metadata\":{\"features\":{\"F1\":{\"featureType\":" + + "\"simple_aggs\",\"aggregationBy\":\"sum\",\"aggregationOf\":\"value\"}},\"filters\":[]},\"last_update_time" + + "\":1630999291783,\"category_field\":[\"type\"],\"detector_type\":\"MULTI_ENTITY\"},\"detection_date_range\"" + + ":{\"start_time\":1628407291580,\"end_time\":1630999291580},\"entity\":[{\"name\":\"type\",\"value\":\"error10\"}]" + + ",\"parent_task_id\":\"a1civ3sBwF58XZxvKrko\",\"worker_node\":\"DL5uOJV3TjOOAyh5hJXrCA\",\"current_piece\"" + + ":1630999260000,\"execution_end_time\":1630999442814}}"; + @Captor + ArgumentCaptor> remoteResponseHandler; + + @Override + public void setUp() throws Exception { + super.setUp(); + Instant now = Instant.now(); + Instant startTime = now.minus(10, ChronoUnit.DAYS); + Instant endTime = now.minus(1, ChronoUnit.DAYS); + detectionDateRange = new DateRange(startTime, endTime); + + settings = Settings + .builder() + .put(MAX_OLD_AD_TASK_DOCS_PER_DETECTOR.getKey(), 2) + .put(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 1) + .put(REQUEST_TIMEOUT.getKey(), TimeValue.timeValueSeconds(10)) + .build(); + + clusterSettings = clusterSetting( + settings, + MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, + BATCH_TASK_PIECE_INTERVAL_SECONDS, + REQUEST_TIMEOUT, + DELETE_AD_RESULT_WHEN_DELETE_DETECTOR, + MAX_BATCH_TASK_PER_NODE, + MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS + ); + + maxBatchTaskPerNode = MAX_BATCH_TASK_PER_NODE.get(settings); + clusterService = spy(new ClusterService(settings, clusterSettings, null)); + + client = mock(Client.class); + nodeFilter = mock(DiscoveryNodeFilterer.class); + detectionIndices = mock(ADIndexManagement.class); + adTaskCacheManager = mock(ADTaskCacheManager.class); + hashRing = mock(HashRing.class); + transportService = mock(TransportService.class); + threadPool = mock(ThreadPool.class); + threadContext = new ThreadContext(settings); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(client.threadPool()).thenReturn(threadPool); + indexAnomalyDetectorJobActionHandler = mock(IndexAnomalyDetectorJobActionHandler.class); + adTaskManager = spy( + new ADTaskManager( + settings, + clusterService, + client, + TestHelpers.xContentRegistry(), + detectionIndices, + nodeFilter, + hashRing, + adTaskCacheManager, + threadPool + ) + ); + + listener = spy(new ActionListener() { + @Override + public void onResponse(AnomalyDetectorJobResponse bulkItemResponses) {} + + @Override + public void onFailure(Exception e) {} + }); + + node1 = new DiscoveryNode( + "nodeName1", + "node1", + new TransportAddress(TransportAddress.META_ADDRESS, 9300), + emptyMap(), + emptySet(), + Version.CURRENT + ); + node2 = new DiscoveryNode( + "nodeName2", + "node2", + new TransportAddress(TransportAddress.META_ADDRESS, 9300), + emptyMap(), + emptySet(), + Version.CURRENT + ); + maxRunningEntities = MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS.get(settings).intValue(); + + ThreadContext threadContext = new ThreadContext(settings); + context = threadContext.stashContext(); + } + + 
private void setupGetDetector(AnomalyDetector detector) { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener + .onResponse( + new GetResponse( + new GetResult( + CommonName.CONFIG_INDEX, + detector.getId(), + UNASSIGNED_SEQ_NO, + 0, + -1, + true, + BytesReference.bytes(detector.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)), + Collections.emptyMap(), + Collections.emptyMap() + ) + ) + ); + return null; + }).when(client).get(any(), any()); + } + + private void setupHashRingWithSameLocalADVersionNodes() { + doAnswer(invocation -> { + Consumer function = invocation.getArgument(0); + function.accept(new DiscoveryNode[] { node1, node2 }); + return null; + }).when(hashRing).getNodesWithSameLocalAdVersion(any(), any()); + } + + private void setupHashRingWithOwningNode() { + doAnswer(invocation -> { + Consumer> function = invocation.getArgument(1); + function.accept(Optional.of(node1)); + return null; + }).when(hashRing).buildAndGetOwningNodeWithSameLocalAdVersion(any(), any(), any()); + } + + public void testCreateTaskIndexNotAcknowledged() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(new CreateIndexResponse(false, false, ANOMALY_RESULT_INDEX_ALIAS)); + return null; + }).when(detectionIndices).initStateIndex(any()); + doReturn(false).when(detectionIndices).doesStateIndexExist(); + AnomalyDetector detector = randomDetector(ImmutableList.of(randomFeature(true)), randomAlphaOfLength(5), 1, randomAlphaOfLength(5)); + setupGetDetector(detector); + + adTaskManager.startDetector(detector, detectionDateRange, randomUser(), transportService, listener); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + String error = String.format(Locale.ROOT, CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); + assertEquals(error, exceptionCaptor.getValue().getMessage()); + } + + public void testCreateTaskIndexWithResourceAlreadyExistsException() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new ResourceAlreadyExistsException("index created")); + return null; + }).when(detectionIndices).initStateIndex(any()); + doReturn(false).when(detectionIndices).doesStateIndexExist(); + AnomalyDetector detector = randomDetector(ImmutableList.of(randomFeature(true)), randomAlphaOfLength(5), 1, randomAlphaOfLength(5)); + setupGetDetector(detector); + + adTaskManager.startDetector(detector, detectionDateRange, randomUser(), transportService, listener); + verify(listener, never()).onFailure(any()); + } + + public void testCreateTaskIndexWithException() throws IOException { + String error = randomAlphaOfLength(5); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onFailure(new RuntimeException(error)); + return null; + }).when(detectionIndices).initStateIndex(any()); + doReturn(false).when(detectionIndices).doesStateIndexExist(); + AnomalyDetector detector = randomDetector(ImmutableList.of(randomFeature(true)), randomAlphaOfLength(5), 1, randomAlphaOfLength(5)); + setupGetDetector(detector); + + adTaskManager.startDetector(detector, detectionDateRange, randomUser(), transportService, listener); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals(error, exceptionCaptor.getValue().getMessage()); + } + + public void testStartDetectorWithNoEnabledFeature() throws IOException { + AnomalyDetector detector = randomDetector( + 
ImmutableList.of(randomFeature(false)), + randomAlphaOfLength(5), + 1, + randomAlphaOfLength(5) + ); + setupGetDetector(detector); + + adTaskManager + .startDetector( + detector.getId(), + detectionDateRange, + indexAnomalyDetectorJobActionHandler, + randomUser(), + transportService, + context, + listener + ); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + } + + @SuppressWarnings("unchecked") + public void testStartDetectorForHistoricalAnalysis() throws IOException { + AnomalyDetector detector = randomDetector(ImmutableList.of(randomFeature(true)), randomAlphaOfLength(5), 1, randomAlphaOfLength(5)); + setupGetDetector(detector); + setupHashRingWithOwningNode(); + + adTaskManager + .startDetector( + detector.getId(), + detectionDateRange, + indexAnomalyDetectorJobActionHandler, + randomUser(), + transportService, + context, + listener + ); + verify(adTaskManager, times(1)).forwardRequestToLeadNode(any(), any(), any()); + } + + private void setupTaskSlots(int node1UsedTaskSlots, int node1AssignedTaskSLots, int node2UsedTaskSlots, int node2AssignedTaskSLots) { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener + .onResponse( + new ADStatsNodesResponse( + new ClusterName(randomAlphaOfLength(5)), + ImmutableList + .of( + new ADStatsNodeResponse( + node1, + ImmutableMap + .of( + InternalStatNames.AD_USED_BATCH_TASK_SLOT_COUNT.getName(), + node1UsedTaskSlots, + InternalStatNames.AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT.getName(), + node1AssignedTaskSLots + ) + ), + new ADStatsNodeResponse( + node2, + ImmutableMap + .of( + InternalStatNames.AD_USED_BATCH_TASK_SLOT_COUNT.getName(), + node2UsedTaskSlots, + InternalStatNames.AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT.getName(), + node2AssignedTaskSLots + ) + ) + ), + ImmutableList.of() + ) + ); + return null; + }).when(client).execute(any(), any(), any()); + } + + public void testCheckTaskSlotsWithNoAvailableTaskSlots() throws IOException { + ADTask adTask = randomAdTask( + randomAlphaOfLength(5), + ADTaskState.INIT, + Instant.now(), + randomAlphaOfLength(5), + TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5))) + ); + setupHashRingWithSameLocalADVersionNodes(); + + setupTaskSlots(0, maxBatchTaskPerNode, maxBatchTaskPerNode, maxBatchTaskPerNode); + + adTaskManager + .checkTaskSlots(adTask, adTask.getDetector(), detectionDateRange, randomUser(), ADTaskAction.START, transportService, listener); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue().getMessage().contains("No available task slot")); + } + + private void setupSearchTopEntities(int entitySize) { + List entities = new ArrayList<>(); + for (int i = 0; i < entitySize; i++) { + entities.add(createSingleAttributeEntity("category", "value" + i)); + } + } + + public void testCheckTaskSlotsWithAvailableTaskSlotsForHC() throws IOException { + ADTask adTask = randomAdTask( + randomAlphaOfLength(5), + ADTaskState.INIT, + Instant.now(), + randomAlphaOfLength(5), + TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5))) + ); + setupSearchTopEntities(4); + setupHashRingWithSameLocalADVersionNodes(); + + setupTaskSlots(0, maxBatchTaskPerNode, maxBatchTaskPerNode, maxBatchTaskPerNode - 1); + + adTaskManager + .checkTaskSlots(adTask, adTask.getDetector(), detectionDateRange, randomUser(), ADTaskAction.START, transportService, listener); + verify(adTaskManager, 
times(1)) + .startHistoricalAnalysis(eq(adTask.getDetector()), eq(detectionDateRange), any(), eq(1), eq(transportService), any()); + } + + public void testCheckTaskSlotsWithAvailableTaskSlotsForSingleEntityDetector() throws IOException { + ADTask adTask = randomAdTask( + randomAlphaOfLength(5), + ADTaskState.INIT, + Instant.now(), + randomAlphaOfLength(5), + TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of()) + ); + setupHashRingWithSameLocalADVersionNodes(); + + setupTaskSlots(0, 2, 2, 1); + + adTaskManager + .checkTaskSlots(adTask, adTask.getDetector(), detectionDateRange, randomUser(), ADTaskAction.START, transportService, listener); + verify(adTaskManager, times(1)) + .startHistoricalAnalysis(eq(adTask.getDetector()), eq(detectionDateRange), any(), eq(1), eq(transportService), any()); + } + + public void testCheckTaskSlotsWithAvailableTaskSlotsAndNoEntity() throws IOException { + ADTask adTask = randomAdTask( + randomAlphaOfLength(5), + ADTaskState.INIT, + Instant.now(), + randomAlphaOfLength(5), + TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5))) + ); + setupSearchTopEntities(0); + setupHashRingWithSameLocalADVersionNodes(); + + setupTaskSlots(0, 2, 2, 1); + + adTaskManager + .checkTaskSlots(adTask, adTask.getDetector(), detectionDateRange, randomUser(), ADTaskAction.START, transportService, listener); + verify(adTaskManager, times(1)).startHistoricalAnalysis(any(), any(), any(), anyInt(), any(), any()); + } + + public void testCheckTaskSlotsWithAvailableTaskSlotsForScale() throws IOException { + ADTask adTask = randomAdTask( + randomAlphaOfLength(5), + ADTaskState.INIT, + Instant.now(), + randomAlphaOfLength(5), + TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5))) + ); + setupSearchTopEntities(4); + setupHashRingWithSameLocalADVersionNodes(); + + setupTaskSlots(0, maxBatchTaskPerNode, maxBatchTaskPerNode, maxBatchTaskPerNode - 1); + + adTaskManager + .checkTaskSlots( + adTask, + adTask.getDetector(), + detectionDateRange, + randomUser(), + ADTaskAction.SCALE_ENTITY_TASK_SLOTS, + transportService, + listener + ); + verify(adTaskManager, times(1)).scaleTaskLaneOnCoordinatingNode(eq(adTask), eq(1), eq(transportService), any()); + } + + public void testDeleteDuplicateTasks() throws IOException { + ADTask adTask = randomAdTask(); + adTaskManager.handleADTaskException(adTask, new DuplicateTaskException("test")); + verify(client, times(1)).delete(any(), any()); + } + + public void testParseEntityForSingleCategoryHC() throws IOException { + ADTask adTask = randomAdTask( + randomAlphaOfLength(5), + ADTaskState.INIT, + Instant.now(), + randomAlphaOfLength(5), + TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5))) + ); + String entityValue = adTaskManager.convertEntityToString(adTask); + Entity entity = adTaskManager.parseEntityFromString(entityValue, adTask); + assertEquals(entity, adTask.getEntity()); + } + + public void testParseEntityForMultiCategoryHC() throws IOException { + ADTask adTask = randomAdTask( + randomAlphaOfLength(5), + ADTaskState.INIT, + Instant.now(), + randomAlphaOfLength(5), + TestHelpers + .randomAnomalyDetectorUsingCategoryFields( + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(5), randomAlphaOfLength(5)) + ) + ); + String entityValue = adTaskManager.convertEntityToString(adTask); + Entity entity = 
adTaskManager.parseEntityFromString(entityValue, adTask); + assertEquals(entity, adTask.getEntity()); + } + + public void testDetectorTaskSlotScaleUpDelta() { + String detectorId = randomAlphaOfLength(5); + DiscoveryNode[] eligibleDataNodes = new DiscoveryNode[] { node1, node2 }; + + // Scale down + when(hashRing.getNodesWithSameLocalAdVersion()).thenReturn(eligibleDataNodes); + when(adTaskCacheManager.getUnfinishedEntityCount(detectorId)).thenReturn(maxRunningEntities * 10); + int taskSlots = maxRunningEntities - 1; + when(adTaskCacheManager.getDetectorTaskSlots(detectorId)).thenReturn(taskSlots); + int delta = adTaskManager.detectorTaskSlotScaleDelta(detectorId); + assertEquals(maxRunningEntities - taskSlots, delta); + } + + public void testDetectorTaskSlotScaleDownDelta() { + String detectorId = randomAlphaOfLength(5); + DiscoveryNode[] eligibleDataNodes = new DiscoveryNode[] { node1, node2 }; + + // Scale down + when(hashRing.getNodesWithSameLocalAdVersion()).thenReturn(eligibleDataNodes); + when(adTaskCacheManager.getUnfinishedEntityCount(detectorId)).thenReturn(maxRunningEntities * 10); + int taskSlots = maxRunningEntities * 5; + when(adTaskCacheManager.getDetectorTaskSlots(detectorId)).thenReturn(taskSlots); + int delta = adTaskManager.detectorTaskSlotScaleDelta(detectorId); + assertEquals(maxRunningEntities - taskSlots, delta); + } + + @SuppressWarnings("unchecked") + public void testGetADTaskWithNullResponse() { + String taskId = randomAlphaOfLength(5); + ActionListener> actionListener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(), any()); + + adTaskManager.getADTask(taskId, actionListener); + verify(actionListener, times(1)).onResponse(eq(Optional.empty())); + } + + @SuppressWarnings("unchecked") + public void testGetADTaskWithNotExistTask() { + String taskId = randomAlphaOfLength(5); + ActionListener> actionListener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = new GetResponse( + new GetResult( + CommonName.JOB_INDEX, + taskId, + UNASSIGNED_SEQ_NO, + 0, + -1, + false, + null, + Collections.emptyMap(), + Collections.emptyMap() + ) + ); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + adTaskManager.getADTask(taskId, actionListener); + verify(actionListener, times(1)).onResponse(eq(Optional.empty())); + } + + @SuppressWarnings("unchecked") + public void testGetADTaskWithIndexNotFoundException() { + String taskId = randomAlphaOfLength(5); + ActionListener> actionListener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new IndexNotFoundException("", "")); + return null; + }).when(client).get(any(), any()); + + adTaskManager.getADTask(taskId, actionListener); + verify(actionListener, times(1)).onResponse(eq(Optional.empty())); + } + + @SuppressWarnings("unchecked") + public void testGetADTaskWithIndexUnknownException() { + String taskId = randomAlphaOfLength(5); + ActionListener> actionListener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("test")); + return null; + }).when(client).get(any(), any()); + + adTaskManager.getADTask(taskId, actionListener); + verify(actionListener, times(1)).onFailure(any()); + } + + 
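For orientation, the two scale-delta tests earlier in this file (testDetectorTaskSlotScaleUpDelta and testDetectorTaskSlotScaleDownDelta) both assert a delta of maxRunningEntities - taskSlots when far more entities remain unfinished than the per-detector limit allows. The sketch below shows only that arithmetic; the Math.min clamp for the case of few unfinished entities is an assumption, since these tests do not exercise it, and the method name is a simplified stand-in for ADTaskManager#detectorTaskSlotScaleDelta.

public final class TaskSlotScaleDeltaSketch {

    // Assumed simplification: target the per-detector running-entity limit (clamped by the
    // remaining work) and report how far the currently assigned task slots are from it.
    // A positive delta means scale up, a negative delta means scale down.
    static int scaleDelta(int maxRunningEntities, int currentTaskSlots, int unfinishedEntities) {
        int target = Math.min(maxRunningEntities, unfinishedEntities);
        return target - currentTaskSlots;
    }

    public static void main(String[] args) {
        int maxRunningEntities = 10;
        // Mirrors testDetectorTaskSlotScaleUpDelta: slots just below the limit -> delta of +1.
        System.out.println(scaleDelta(maxRunningEntities, maxRunningEntities - 1, maxRunningEntities * 10));
        // Mirrors testDetectorTaskSlotScaleDownDelta: slots well above the limit -> negative delta.
        System.out.println(scaleDelta(maxRunningEntities, maxRunningEntities * 5, maxRunningEntities * 10));
    }
}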
@SuppressWarnings("unchecked") + public void testGetADTaskWithExistingTask() { + String taskId = randomAlphaOfLength(5); + ActionListener> actionListener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + ADTask adTask = randomAdTask(); + GetResponse response = new GetResponse( + new GetResult( + CommonName.JOB_INDEX, + taskId, + UNASSIGNED_SEQ_NO, + 0, + -1, + true, + BytesReference.bytes(adTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)), + Collections.emptyMap(), + Collections.emptyMap() + ) + ); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + adTaskManager.getADTask(taskId, actionListener); + verify(actionListener, times(1)).onResponse(any()); + } + + @SuppressWarnings("unchecked") + public void testUpdateLatestRealtimeTaskOnCoordinatingNode() { + String detectorId = randomAlphaOfLength(5); + String state = ADTaskState.RUNNING.name(); + Long rcfTotalUpdates = randomLongBetween(200, 1000); + Long detectorIntervalInMinutes = 1L; + String error = randomAlphaOfLength(5); + ActionListener actionListener = mock(ActionListener.class); + doReturn(node1).when(clusterService).localNode(); + when(adTaskCacheManager.isRealtimeTaskChangeNeeded(anyString(), anyString(), anyFloat(), anyString())).thenReturn(true); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new UpdateResponse(ShardId.fromString("[test][1]"), "1", 0L, 1L, 1L, DocWriteResponse.Result.UPDATED)); + return null; + }).when(adTaskManager).updateLatestADTask(anyString(), any(), anyMap(), any()); + adTaskManager + .updateLatestRealtimeTaskOnCoordinatingNode( + detectorId, + state, + rcfTotalUpdates, + detectorIntervalInMinutes, + error, + actionListener + ); + verify(actionListener, times(1)).onResponse(any()); + } + + public void testGetLocalADTaskProfilesByDetectorId() { + doReturn(node1).when(clusterService).localNode(); + when(adTaskCacheManager.isHCTaskRunning(anyString())).thenReturn(true); + when(adTaskCacheManager.isHCTaskCoordinatingNode(anyString())).thenReturn(true); + List tasksOfDetector = ImmutableList.of(randomAlphaOfLength(5)); + when(adTaskCacheManager.getTasksOfDetector(anyString())).thenReturn(tasksOfDetector); + Deque>> shingle = new LinkedBlockingDeque<>(); + when(adTaskCacheManager.getShingle(anyString())).thenReturn(shingle); + ThresholdedRandomCutForest trcf = mock(ThresholdedRandomCutForest.class); + when(adTaskCacheManager.getTRcfModel(anyString())).thenReturn(trcf); + RandomCutForest rcf = mock(RandomCutForest.class); + when(trcf.getForest()).thenReturn(rcf); + when(rcf.getTotalUpdates()).thenReturn(randomLongBetween(100, 1000)); + when(adTaskCacheManager.isThresholdModelTrained(anyString())).thenReturn(true); + when(adTaskCacheManager.getThresholdModelTrainingDataSize(anyString())).thenReturn(randomIntBetween(100, 1000)); + when(adTaskCacheManager.getModelSize(anyString())).thenReturn(randomLongBetween(100, 1000)); + Entity entity = createSingleAttributeEntity(randomAlphaOfLength(5), randomAlphaOfLength(5)); + when(adTaskCacheManager.getEntity(anyString())).thenReturn(entity); + String detectorId = randomAlphaOfLength(5); + + ExecutorService executeService = mock(ExecutorService.class); + when(threadPool.executor(anyString())).thenReturn(executeService); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executeService).execute(any()); + + ADTaskProfile taskProfile = 
adTaskManager.getLocalADTaskProfilesByDetectorId(detectorId); + assertEquals(1, taskProfile.getEntityTaskProfiles().size()); + verify(adTaskCacheManager, times(1)).cleanExpiredHCBatchTaskRunStates(); + } + + @SuppressWarnings("unchecked") + public void testRemoveStaleRunningEntity() throws IOException { + ActionListener actionListener = mock(ActionListener.class); + ADTask adTask = randomAdTask(); + String entity = randomAlphaOfLength(5); + ExecutorService executeService = mock(ExecutorService.class); + when(threadPool.executor(anyString())).thenReturn(executeService); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executeService).execute(any()); + when(adTaskCacheManager.removeRunningEntity(anyString(), anyString())).thenReturn(true); + when(adTaskCacheManager.getPendingEntityCount(anyString())).thenReturn(randomIntBetween(1, 10)); + adTaskManager.removeStaleRunningEntity(adTask, entity, transportService, actionListener); + verify(adTaskManager, times(1)).runNextEntityForHCADHistorical(any(), any(), any()); + + when(adTaskCacheManager.removeRunningEntity(anyString(), anyString())).thenReturn(false); + when(adTaskCacheManager.hasEntity(anyString())).thenReturn(false); + adTaskManager.removeStaleRunningEntity(adTask, entity, transportService, actionListener); + verify(adTaskManager, times(1)).setHCDetectorTaskDone(any(), any(), any()); + + when(adTaskCacheManager.hasEntity(anyString())).thenReturn(true); + adTaskManager.removeStaleRunningEntity(adTask, entity, transportService, actionListener); + verify(adTaskManager, times(1)).setHCDetectorTaskDone(any(), any(), any()); + } + + public void testResetLatestFlagAsFalse() throws IOException { + List adTasks = new ArrayList<>(); + adTaskManager.resetLatestFlagAsFalse(adTasks); + verify(client, never()).execute(any(), any(), any()); + + ADTask adTask = randomAdTask(); + adTasks.add(adTask); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + BulkItemResponse[] responses = new BulkItemResponse[1]; + ShardId shardId = new ShardId(new Index("index_name", "uuid"), 0); + responses[0] = new BulkItemResponse( + 0, + randomFrom(DocWriteRequest.OpType.values()), + new IndexResponse(shardId, "id", 1, 1, 1, true) + ); + listener.onResponse(new BulkResponse(responses, 1)); + return null; + }).when(client).execute(any(), any(), any()); + adTaskManager.resetLatestFlagAsFalse(adTasks); + verify(client, times(1)).execute(any(), any(), any()); + } + + public void testCleanADResultOfDeletedDetectorWithNoDeletedDetector() { + when(adTaskCacheManager.pollDeletedDetector()).thenReturn(null); + adTaskManager.cleanADResultOfDeletedDetector(); + verify(client, never()).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); + } + + public void testCleanADResultOfDeletedDetectorWithException() { + String detectorId = randomAlphaOfLength(5); + when(adTaskCacheManager.pollDeletedDetector()).thenReturn(detectorId); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("test")); + return null; + }).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + BulkByScrollResponse deleteByQueryResponse = mock(BulkByScrollResponse.class); + listener.onResponse(deleteByQueryResponse); + return null; + }).when(client).execute(any(), any(), any()); + + settings = Settings + .builder() + .put(MAX_OLD_AD_TASK_DOCS_PER_DETECTOR.getKey(), 2) + 
.put(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 1) + .put(REQUEST_TIMEOUT.getKey(), TimeValue.timeValueSeconds(10)) + .put(DELETE_AD_RESULT_WHEN_DELETE_DETECTOR.getKey(), true) + .build(); + + clusterSettings = clusterSetting( + settings, + MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, + BATCH_TASK_PIECE_INTERVAL_SECONDS, + REQUEST_TIMEOUT, + DELETE_AD_RESULT_WHEN_DELETE_DETECTOR, + MAX_BATCH_TASK_PER_NODE, + MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS + ); + + clusterService = spy(new ClusterService(settings, clusterSettings, null)); + + ADTaskManager adTaskManager = spy( + new ADTaskManager( + settings, + clusterService, + client, + TestHelpers.xContentRegistry(), + detectionIndices, + nodeFilter, + hashRing, + adTaskCacheManager, + threadPool + ) + ); + adTaskManager.cleanADResultOfDeletedDetector(); + verify(client, times(1)).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); + verify(adTaskCacheManager, times(1)).addDeletedDetector(eq(detectorId)); + + adTaskManager.cleanADResultOfDeletedDetector(); + verify(client, times(2)).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); + verify(adTaskCacheManager, times(1)).addDeletedDetector(eq(detectorId)); + } + + public void testMaintainRunningHistoricalTasksWithOwningNodeIsNotLocalNode() { + // Test no owning node + when(hashRing.getOwningNodeWithHighestAdVersion(anyString())).thenReturn(Optional.empty()); + adTaskManager.maintainRunningHistoricalTasks(transportService, 10); + verify(client, never()).search(any(), any()); + + // Test owning node is not local node + when(hashRing.getOwningNodeWithHighestAdVersion(anyString())).thenReturn(Optional.of(node2)); + doReturn(node1).when(clusterService).localNode(); + adTaskManager.maintainRunningHistoricalTasks(transportService, 10); + verify(client, never()).search(any(), any()); + } + + public void testMaintainRunningHistoricalTasksWithNoRunningTask() { + when(hashRing.getOwningNodeWithHighestAdVersion(anyString())).thenReturn(Optional.of(node1)); + doReturn(node1).when(clusterService).localNode(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse response = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + response, + null, + 1, + 1, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + adTaskManager.maintainRunningHistoricalTasks(transportService, 10); + verify(client, times(1)).search(any(), any()); + } + + public void testMaintainRunningHistoricalTasksWithRunningTask() { + when(hashRing.getOwningNodeWithHighestAdVersion(anyString())).thenReturn(Optional.of(node1)); + doReturn(node1).when(clusterService).localNode(); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(threadPool).schedule(any(), any(), anyString()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + SearchHit task = SearchHit.fromXContent(TestHelpers.parser(runningHistoricalHCTaskContent)); + SearchHits searchHits = new SearchHits(new SearchHit[] { task }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse response = new InternalSearchResponse( + searchHits, + 
InternalAggregations.EMPTY, + null, + null, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + response, + null, + 1, + 1, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + adTaskManager.maintainRunningHistoricalTasks(transportService, 10); + verify(client, times(1)).search(any(), any()); + } + + public void testMaintainRunningRealtimeTasksWithNoRealtimeTask() { + when(adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()).thenReturn(null); + adTaskManager.maintainRunningRealtimeTasks(); + verify(adTaskCacheManager, never()).removeRealtimeTaskCache(anyString()); + + when(adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()).thenReturn(new String[0]); + adTaskManager.maintainRunningRealtimeTasks(); + verify(adTaskCacheManager, never()).removeRealtimeTaskCache(anyString()); + } + + public void testMaintainRunningRealtimeTasks() { + String detectorId1 = randomAlphaOfLength(5); + String detectorId2 = randomAlphaOfLength(5); + String detectorId3 = randomAlphaOfLength(5); + when(adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()).thenReturn(new String[] { detectorId1, detectorId2, detectorId3 }); + when(adTaskCacheManager.getRealtimeTaskCache(detectorId1)).thenReturn(null); + + ADRealtimeTaskCache cacheOfDetector2 = mock(ADRealtimeTaskCache.class); + when(cacheOfDetector2.expired()).thenReturn(false); + when(adTaskCacheManager.getRealtimeTaskCache(detectorId2)).thenReturn(cacheOfDetector2); + + ADRealtimeTaskCache cacheOfDetector3 = mock(ADRealtimeTaskCache.class); + when(cacheOfDetector3.expired()).thenReturn(true); + when(adTaskCacheManager.getRealtimeTaskCache(detectorId3)).thenReturn(cacheOfDetector3); + + adTaskManager.maintainRunningRealtimeTasks(); + verify(adTaskCacheManager, times(1)).removeRealtimeTaskCache(anyString()); + } + + @SuppressWarnings("unchecked") + public void testStartHistoricalAnalysisWithNoOwningNode() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableList.of()); + DateRange detectionDateRange = TestHelpers.randomDetectionDateRange(); + User user = null; + int availableTaskSlots = randomIntBetween(1, 10); + ActionListener listener = mock(ActionListener.class); + doAnswer(invocation -> { + Consumer> function = invocation.getArgument(1); + function.accept(Optional.empty()); + return null; + }).when(hashRing).buildAndGetOwningNodeWithSameLocalAdVersion(anyString(), any(), any()); + adTaskManager.startHistoricalAnalysis(detector, detectionDateRange, user, availableTaskSlots, transportService, listener); + verify(listener, times(1)).onFailure(any()); + } + + @SuppressWarnings("unchecked") + public void testGetAndExecuteOnLatestADTasksWithRunningRealtimeTaskWithTaskStopped() throws IOException { + String detectorId = randomAlphaOfLength(5); + Consumer> function = mock(Consumer.class); + AnomalyDetector detector = TestHelpers + .randomDetector( + ImmutableList.of(randomFeature(true)), + randomAlphaOfLength(5), + randomIntBetween(1, 10), + MockSimpleLog.TIME_FIELD, + ImmutableList.of(randomAlphaOfLength(5)) + ); + ADTask adTask = ADTask + .builder() + .taskId(randomAlphaOfLength(5)) + .taskType(ADTaskType.HISTORICAL_HC_DETECTOR.name()) + .detectorId(randomAlphaOfLength(5)) + .detector(detector) + .entity(null) + .state(ADTaskState.RUNNING.name()) + .taskProgress(0.5f) + .initProgress(1.0f) + 
.currentPiece(Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(randomIntBetween(1, 100), ChronoUnit.MINUTES)) + .executionStartTime(Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(100, ChronoUnit.MINUTES)) + .isLatest(true) + .error(randomAlphaOfLength(5)) + .checkpointId(randomAlphaOfLength(5)) + .lastUpdateTime(Instant.now().truncatedTo(ChronoUnit.SECONDS)) + .startedBy(randomAlphaOfLength(5)) + .lastUpdateTime(Instant.now().truncatedTo(ChronoUnit.SECONDS)) + .coordinatingNode(node1.getId()) + .build(); + ADTaskProfile profile = new ADTaskProfile( + adTask, + randomInt(), + randomLong(), + randomBoolean(), + randomInt(), + randomLong(), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomInt(), + randomBoolean(), + randomInt(), + randomInt(), + randomInt(), + ImmutableList.of(randomAlphaOfLength(5)), + Instant.now().toEpochMilli() + ); + setupGetAndExecuteOnLatestADTasks(profile); + adTaskManager + .getAndExecuteOnLatestADTasks( + detectorId, + null, + null, + ADTaskType.ALL_DETECTOR_TASK_TYPES, + function, + transportService, + true, + 10, + listener + ); + verify(client, times(2)).update(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testGetAndExecuteOnLatestADTasksWithRunningHistoricalTask() throws IOException { + String detectorId = randomAlphaOfLength(5); + Consumer> function = mock(Consumer.class); + AnomalyDetector detector = TestHelpers + .randomDetector( + ImmutableList.of(randomFeature(true)), + randomAlphaOfLength(5), + randomIntBetween(1, 10), + MockSimpleLog.TIME_FIELD, + ImmutableList.of(randomAlphaOfLength(5)) + ); + ADTask adTask = ADTask + .builder() + .taskId(historicalTaskId) + .taskType(ADTaskType.HISTORICAL_HC_DETECTOR.name()) + .detectorId(randomAlphaOfLength(5)) + .detector(detector) + .entity(null) + .state(ADTaskState.RUNNING.name()) + .taskProgress(0.5f) + .initProgress(1.0f) + .currentPiece(Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(randomIntBetween(1, 100), ChronoUnit.MINUTES)) + .executionStartTime(Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(100, ChronoUnit.MINUTES)) + .isLatest(true) + .error(randomAlphaOfLength(5)) + .checkpointId(randomAlphaOfLength(5)) + .lastUpdateTime(Instant.now().truncatedTo(ChronoUnit.SECONDS)) + .startedBy(randomAlphaOfLength(5)) + .lastUpdateTime(Instant.now().truncatedTo(ChronoUnit.SECONDS)) + .coordinatingNode(node1.getId()) + .build(); + ADTaskProfile profile = new ADTaskProfile( + adTask, + randomInt(), + randomLong(), + randomBoolean(), + randomInt(), + randomLong(), + randomAlphaOfLength(5), + historicalTaskId, + randomAlphaOfLength(5), + randomInt(), + randomBoolean(), + randomInt(), + randomInt(), + 2, + ImmutableList.of(randomAlphaOfLength(5), randomAlphaOfLength(5)), + Instant.now().toEpochMilli() + ); + setupGetAndExecuteOnLatestADTasks(profile); + adTaskManager + .getAndExecuteOnLatestADTasks( + detectorId, + null, + null, + ADTaskType.ALL_DETECTOR_TASK_TYPES, + function, + transportService, + true, + 10, + listener + ); + verify(client, times(2)).update(any(), any()); + } + + @SuppressWarnings("unchecked") + private void setupGetAndExecuteOnLatestADTasks(ADTaskProfile adTaskProfile) { + String runningRealtimeHCTaskContent = runningHistoricalHCTaskContent + .replace(ADTaskType.HISTORICAL_HC_DETECTOR.name(), ADTaskType.REALTIME_HC_DETECTOR.name()) + .replace(historicalTaskId, realtimeTaskId); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + SearchHit historicalTask = 
SearchHit.fromXContent(TestHelpers.parser(runningHistoricalHCTaskContent)); + SearchHit realtimeTask = SearchHit.fromXContent(TestHelpers.parser(runningRealtimeHCTaskContent)); + SearchHits searchHits = new SearchHits( + new SearchHit[] { historicalTask, realtimeTask }, + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + Float.NaN + ); + InternalSearchResponse response = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + response, + null, + 1, + 1, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + String detectorId = randomAlphaOfLength(5); + Consumer> function = mock(Consumer.class); + ActionListener listener = mock(ActionListener.class); + + doAnswer(invocation -> { + Consumer getNodeFunction = invocation.getArgument(0); + getNodeFunction.accept(new DiscoveryNode[] { node1, node2 }); + return null; + }).when(hashRing).getAllEligibleDataNodesWithKnownAdVersion(any(), any()); + + doAnswer(invocation -> { + ActionListener taskProfileResponseListener = invocation.getArgument(2); + AnomalyDetector detector = TestHelpers + .randomDetector( + ImmutableList.of(randomFeature(true)), + randomAlphaOfLength(5), + randomIntBetween(1, 10), + MockSimpleLog.TIME_FIELD, + ImmutableList.of(randomAlphaOfLength(5)) + ); + ADTaskProfileNodeResponse nodeResponse = new ADTaskProfileNodeResponse(node1, adTaskProfile, Version.CURRENT); + ImmutableList nodes = ImmutableList.of(nodeResponse); + ADTaskProfileResponse taskProfileResponse = new ADTaskProfileResponse(new ClusterName("test"), nodes, ImmutableList.of()); + taskProfileResponseListener.onResponse(taskProfileResponse); + return null; + }).doAnswer(invocation -> { + ActionListener updateResponselistener = invocation.getArgument(2); + BulkByScrollResponse response = mock(BulkByScrollResponse.class); + when(response.getBulkFailures()).thenReturn(null); + updateResponselistener.onResponse(response); + return null; + }).when(client).execute(any(), any(), any()); + when(nodeFilter.getEligibleDataNodes()).thenReturn(new DiscoveryNode[] { node1, node2 }); + + doAnswer(invocation -> { + ActionListener updateResponselistener = invocation.getArgument(1); + UpdateResponse response = new UpdateResponse(ShardId.fromString("[test][1]"), "1", 0L, 1L, 1L, DocWriteResponse.Result.UPDATED); + updateResponselistener.onResponse(response); + return null; + }).when(client).update(any(), any()); + + doAnswer(invocation -> { + ActionListener getResponselistener = invocation.getArgument(1); + GetResponse response = new GetResponse( + new GetResult( + CommonName.JOB_INDEX, + detectorId, + UNASSIGNED_SEQ_NO, + 0, + -1, + true, + BytesReference + .bytes( + new AnomalyDetectorJob( + detectorId, + randomIntervalSchedule(), + randomIntervalTimeConfiguration(), + false, + Instant.now().minusSeconds(60), + Instant.now(), + Instant.now(), + 60L, + TestHelpers.randomUser(), + null + ).toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS) + ), + Collections.emptyMap(), + Collections.emptyMap() + ) + ); + getResponselistener.onResponse(response); + return null; + }).when(client).get(any(), any()); + } + + @SuppressWarnings("unchecked") + public void testCreateADTaskDirectlyWithException() throws IOException { + ADTask adTask = randomAdTask(ADTaskType.HISTORICAL_HC_DETECTOR); + Consumer function = mock(Consumer.class); + ActionListener listener = 
mock(ActionListener.class); + doThrow(new RuntimeException("test")).when(client).index(any(), any()); + + adTaskManager.createADTaskDirectly(adTask, function, listener); + verify(listener, times(1)).onFailure(any()); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RuntimeException("test")); + return null; + }).when(client).index(any(), any()); + adTaskManager.createADTaskDirectly(adTask, function, listener); + verify(listener, times(2)).onFailure(any()); + } + + public void testCleanChildTasksAndADResultsOfDeletedTaskWithNoDeletedDetectorTask() { + when(adTaskCacheManager.hasDeletedDetectorTask()).thenReturn(false); + adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); + verify(client, never()).execute(any(), any(), any()); + } + + public void testCleanChildTasksAndADResultsOfDeletedTaskWithNullTask() { + when(adTaskCacheManager.hasDeletedDetectorTask()).thenReturn(true); + when(adTaskCacheManager.pollDeletedDetectorTask()).thenReturn(null); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("test")); + return null; + }).when(client).execute(any(), any(), any()); + + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(threadPool).schedule(any(), any(), any()); + + adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); + verify(client, never()).execute(any(), any(), any()); + } + + public void testCleanChildTasksAndADResultsOfDeletedTaskWithFailToDeleteADResult() { + when(adTaskCacheManager.hasDeletedDetectorTask()).thenReturn(true); + when(adTaskCacheManager.pollDeletedDetectorTask()).thenReturn(randomAlphaOfLength(5)); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("test")); + return null; + }).when(client).execute(any(), any(), any()); + + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(threadPool).schedule(any(), any(), any()); + + adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); + verify(client, times(1)).execute(any(), any(), any()); + } + + public void testCleanChildTasksAndADResultsOfDeletedTask() { + when(adTaskCacheManager.hasDeletedDetectorTask()).thenReturn(true); + when(adTaskCacheManager.pollDeletedDetectorTask()).thenReturn(randomAlphaOfLength(5)).thenReturn(null); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + BulkByScrollResponse response = mock(BulkByScrollResponse.class); + actionListener.onResponse(response); + return null; + }).when(client).execute(any(), any(), any()); + + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(threadPool).schedule(any(), any(), any()); + + adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); + verify(client, times(2)).execute(any(), any(), any()); + } + + @SuppressWarnings("unchecked") + public void testDeleteADTasks() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + BulkByScrollResponse response = mock(BulkByScrollResponse.class); + when(response.getBulkFailures()).thenReturn(null); + actionListener.onResponse(response); + return null; + }).when(client).execute(any(), any(), any()); + + String detectorId = randomAlphaOfLength(5); + ExecutorFunction function = mock(ExecutorFunction.class); + 
ActionListener listener = mock(ActionListener.class); + adTaskManager.deleteADTasks(detectorId, function, listener); + verify(function, times(1)).execute(); + } + + @SuppressWarnings("unchecked") + public void testDeleteADTasksWithBulkFailures() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + BulkByScrollResponse response = mock(BulkByScrollResponse.class); + List failures = ImmutableList + .of( + new BulkItemResponse.Failure( + DETECTION_STATE_INDEX, + randomAlphaOfLength(5), + new VersionConflictEngineException(new ShardId(DETECTION_STATE_INDEX, "", 1), "id", "test") + ) + ); + when(response.getBulkFailures()).thenReturn(failures); + actionListener.onResponse(response); + return null; + }).when(client).execute(any(), any(), any()); + + String detectorId = randomAlphaOfLength(5); + ExecutorFunction function = mock(ExecutorFunction.class); + ActionListener listener = mock(ActionListener.class); + adTaskManager.deleteADTasks(detectorId, function, listener); + verify(listener, times(1)).onFailure(any()); + } + + @SuppressWarnings("unchecked") + public void testDeleteADTasksWithException() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new IndexNotFoundException(DETECTION_STATE_INDEX)); + return null; + }).doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("test")); + return null; + }).when(client).execute(any(), any(), any()); + + String detectorId = randomAlphaOfLength(5); + ExecutorFunction function = mock(ExecutorFunction.class); + ActionListener listener = mock(ActionListener.class); + + adTaskManager.deleteADTasks(detectorId, function, listener); + verify(function, times(1)).execute(); + verify(listener, never()).onFailure(any()); + + adTaskManager.deleteADTasks(detectorId, function, listener); + verify(function, times(1)).execute(); + verify(listener, times(1)).onFailure(any()); + } + + @SuppressWarnings("unchecked") + public void testScaleUpTaskSlots() throws IOException { + ADTask adTask = randomAdTask(ADTaskType.HISTORICAL_HC_ENTITY); + ActionListener listener = mock(ActionListener.class); + when(adTaskCacheManager.getAvailableNewEntityTaskLanes(anyString())).thenReturn(0); + doReturn(2).when(adTaskManager).detectorTaskSlotScaleDelta(anyString()); + when(adTaskCacheManager.getLastScaleEntityTaskLaneTime(anyString())).thenReturn(null); + + assertEquals(0, adTaskManager.scaleTaskSlots(adTask, transportService, listener)); + + when(adTaskCacheManager.getLastScaleEntityTaskLaneTime(anyString())).thenReturn(Instant.now()); + assertEquals(2, adTaskManager.scaleTaskSlots(adTask, transportService, listener)); + + when(adTaskCacheManager.getLastScaleEntityTaskLaneTime(anyString())).thenReturn(Instant.now().minus(10, ChronoUnit.DAYS)); + assertEquals(2, adTaskManager.scaleTaskSlots(adTask, transportService, listener)); + verify(adTaskCacheManager, times(1)).refreshLastScaleEntityTaskLaneTime(anyString()); + verify(adTaskManager, times(1)).forwardScaleTaskSlotRequestToLeadNode(any(), any(), any()); + } + + @SuppressWarnings("unchecked") + public void testForwardRequestToLeadNodeWithNotExistingNode() throws IOException { + ADTask adTask = randomAdTask(ADTaskType.HISTORICAL_HC_ENTITY); + ForwardADTaskRequest forwardADTaskRequest = new ForwardADTaskRequest(adTask, ADTaskAction.APPLY_FOR_TASK_SLOTS); + ActionListener listener = mock(ActionListener.class); + doAnswer(invocation -> { + Consumer> function = 
invocation.getArgument(1); + function.accept(Optional.empty()); + return null; + }).when(hashRing).buildAndGetOwningNodeWithSameLocalAdVersion(any(), any(), any()); + + adTaskManager.forwardRequestToLeadNode(forwardADTaskRequest, transportService, listener); + verify(listener, times(1)).onFailure(any()); + } + + @SuppressWarnings("unchecked") + public void testScaleTaskLaneOnCoordinatingNode() { + ADTask adTask = mock(ADTask.class); + when(adTask.getCoordinatingNode()).thenReturn(node1.getId()); + when(nodeFilter.getEligibleDataNodes()).thenReturn(new DiscoveryNode[] { node1, node2 }); + ActionListener listener = mock(ActionListener.class); + adTaskManager.scaleTaskLaneOnCoordinatingNode(adTask, 2, transportService, listener); + } + + @SuppressWarnings("unchecked") + public void testStartDetectorWithException() throws IOException { + AnomalyDetector detector = randomAnomalyDetector(ImmutableList.of(randomFeature(true))); + DateRange detectionDateRange = randomDetectionDateRange(); + User user = null; + ActionListener listener = mock(ActionListener.class); + when(detectionIndices.doesStateIndexExist()).thenReturn(false); + doThrow(new RuntimeException("test")).when(detectionIndices).initStateIndex(any()); + adTaskManager.startDetector(detector, detectionDateRange, user, transportService, listener); + verify(listener, times(1)).onFailure(any()); + } + + @SuppressWarnings("unchecked") + public void testStopDetectorWithNonExistingDetector() { + String detectorId = randomAlphaOfLength(5); + boolean historical = true; + ActionListener listener = mock(ActionListener.class); + doAnswer(invocation -> { + Consumer> function = invocation.getArgument(1); + function.accept(Optional.empty()); + return null; + }).when(adTaskManager).getDetector(anyString(), any(), any()); + adTaskManager.stopDetector(detectorId, historical, indexAnomalyDetectorJobActionHandler, null, transportService, listener); + verify(listener, times(1)).onFailure(any()); + } + + @SuppressWarnings("unchecked") + public void testStopDetectorWithNonExistingTask() { + String detectorId = randomAlphaOfLength(5); + boolean historical = true; + ActionListener listener = mock(ActionListener.class); + doAnswer(invocation -> { + Consumer> function = invocation.getArgument(1); + AnomalyDetector detector = randomAnomalyDetector(ImmutableList.of(randomFeature(true))); + function.accept(Optional.of(detector)); + return null; + }).when(adTaskManager).getDetector(anyString(), any(), any()); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(null); + return null; + }).when(client).search(any(), any()); + + adTaskManager.stopDetector(detectorId, historical, indexAnomalyDetectorJobActionHandler, null, transportService, listener); + verify(listener, times(1)).onFailure(any()); + } + + @SuppressWarnings("unchecked") + public void testStopDetectorWithTaskDone() { + String detectorId = randomAlphaOfLength(5); + boolean historical = true; + ActionListener listener = mock(ActionListener.class); + doAnswer(invocation -> { + Consumer> function = invocation.getArgument(1); + AnomalyDetector detector = randomAnomalyDetector(ImmutableList.of(randomFeature(true))); + function.accept(Optional.of(detector)); + return null; + }).when(adTaskManager).getDetector(anyString(), any(), any()); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + SearchHit task = SearchHit.fromXContent(TestHelpers.parser(taskContent)); + SearchHits searchHits = new SearchHits(new 
SearchHit[] { task }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse response = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + response, + null, + 1, + 1, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + adTaskManager.stopDetector(detectorId, historical, indexAnomalyDetectorJobActionHandler, null, transportService, listener); + verify(listener, times(1)).onFailure(any()); + } + + @SuppressWarnings("unchecked") + public void testGetDetectorWithWrongContent() { + String detectorId = randomAlphaOfLength(5); + Consumer> function = mock(Consumer.class); + ActionListener listener = mock(ActionListener.class); + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + GetResponse response = new GetResponse( + new GetResult( + CommonName.CONFIG_INDEX, + detectorId, + UNASSIGNED_SEQ_NO, + 0, + -1, + true, + BytesReference + .bytes( + new MockSimpleLog(Instant.now(), 1.0, "127.0.0.1", "category", true, "test") + .toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS) + ), + Collections.emptyMap(), + Collections.emptyMap() + ) + ); + responseListener.onResponse(response); + return null; + }).when(client).get(any(), any()); + adTaskManager.getDetector(detectorId, function, listener); + verify(listener, times(1)).onFailure(any()); + } + + @SuppressWarnings("unchecked") + public void testDeleteTaskDocs() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + SearchHit task = SearchHit.fromXContent(TestHelpers.parser(taskContent)); + SearchHits searchHits = new SearchHits(new SearchHit[] { task }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), Float.NaN); + InternalSearchResponse response = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + response, + null, + 1, + 1, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + BulkItemResponse[] responses = new BulkItemResponse[1]; + ShardId shardId = new ShardId(new Index("index_name", "uuid"), 0); + responses[0] = new BulkItemResponse( + 0, + randomFrom(DocWriteRequest.OpType.values()), + new IndexResponse(shardId, "id", 1, 1, 1, true) + ); + responseListener.onResponse(new BulkResponse(responses, 1)); + return null; + }).when(client).execute(any(), any(), any()); + + String detectorId = randomAlphaOfLength(5); + SearchRequest searchRequest = mock(SearchRequest.class); + ExecutorFunction function = mock(ExecutorFunction.class); + ActionListener listener = mock(ActionListener.class); + adTaskManager.deleteTaskDocs(detectorId, searchRequest, function, listener); + verify(adTaskCacheManager, times(1)).addDeletedDetectorTask(anyString()); + verify(function, times(1)).execute(); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultRequestTests.java-e b/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultRequestTests.java-e new file mode 100644 index 000000000..dd200f8f4 --- /dev/null +++ 
b/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultRequestTests.java-e @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.model.ADTask; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; + +public class ADBatchAnomalyResultRequestTests extends OpenSearchTestCase { + + public void testInvalidRequestWithNullTaskIdAndDetectionDateRange() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(); + adTask.setTaskId(null); + adTask.setDetectionDateRange(null); + ADBatchAnomalyResultRequest request = new ADBatchAnomalyResultRequest(adTask); + ActionRequestValidationException exception = request.validate(); + assertEquals( + "Validation Failed: 1: Task id can't be null;2: Detection date range can't be null for batch task;", + exception.getMessage() + ); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java index 1654ede0c..1e3a3506e 100644 --- a/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java @@ -29,8 +29,8 @@ import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.util.ExceptionUtil; -import org.opensearch.common.io.stream.NotSerializableExceptionWrapper; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.EndRunException; diff --git a/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java-e new file mode 100644 index 000000000..ac94e7c31 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java-e @@ -0,0 +1,192 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.settings.ADEnabledSetting.AD_ENABLED; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; +import static org.opensearch.timeseries.TestHelpers.HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.concurrent.TimeUnit; + +import org.junit.Before; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.get.GetResponse; +import org.opensearch.ad.HistoricalAnalysisIntegTestCase; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.model.DateRange; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2) +public class ADBatchAnomalyResultTransportActionTests extends HistoricalAnalysisIntegTestCase { + + private String testIndex; + private Instant startTime; + private Instant endTime; + private String type = "error"; + private int detectionIntervalInMinutes = 1; + private DateRange dateRange; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + testIndex = "test_historical_data"; + startTime = Instant.now().minus(10, ChronoUnit.DAYS); + endTime = Instant.now(); + dateRange = new DateRange(endTime, endTime.plus(10, ChronoUnit.DAYS)); + ingestTestData(testIndex, startTime, detectionIntervalInMinutes, type); + createDetectionStateIndex(); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return Settings + .builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 1) + .put(MAX_BATCH_TASK_PER_NODE.getKey(), 1) + .build(); + } + + public void testAnomalyDetectorWithNullDetector() { + ADTask task = randomCreatedADTask(randomAlphaOfLength(5), null, dateRange); + ADBatchAnomalyResultRequest request = new ADBatchAnomalyResultRequest(task); + ActionRequestValidationException exception = expectThrows( + ActionRequestValidationException.class, + () -> client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(30_000) + ); + assertTrue(exception.getMessage().contains("Detector can't be null")); + } + + public void testHistoricalAnalysisWithFutureDateRange() throws IOException, InterruptedException { + DateRange dateRange = new DateRange(endTime, endTime.plus(10, ChronoUnit.DAYS)); + testInvalidDetectionDateRange(dateRange); + } + + public void testHistoricalAnalysisWithInvalidHistoricalDateRange() throws IOException, InterruptedException { + DateRange dateRange = new DateRange(startTime.minus(10, ChronoUnit.DAYS), startTime); + testInvalidDetectionDateRange(dateRange); + } + + public void testHistoricalAnalysisWithSmallHistoricalDateRange() throws IOException, InterruptedException { + DateRange dateRange = new DateRange(startTime, startTime.plus(10, 
ChronoUnit.MINUTES)); + testInvalidDetectionDateRange(dateRange, "There is not enough data to train model"); + } + + public void testHistoricalAnalysisWithValidDateRange() throws IOException, InterruptedException { + DateRange dateRange = new DateRange(startTime, endTime); + ADBatchAnomalyResultRequest request = adBatchAnomalyResultRequest(dateRange); + client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(5000); + Thread.sleep(20000); + GetResponse doc = getDoc(ADCommonName.DETECTION_STATE_INDEX, request.getAdTask().getTaskId()); + assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(doc.getSourceAsMap().get(ADTask.STATE_FIELD))); + } + + public void testHistoricalAnalysisWithNonExistingIndex() throws IOException { + ADBatchAnomalyResultRequest request = adBatchAnomalyResultRequest(new DateRange(startTime, endTime), randomAlphaOfLength(5)); + client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(10_000); + } + + public void testHistoricalAnalysisExceedsMaxRunningTaskLimit() throws IOException, InterruptedException { + updateTransientSettings(ImmutableMap.of(MAX_BATCH_TASK_PER_NODE.getKey(), 1)); + updateTransientSettings(ImmutableMap.of(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 5)); + DateRange dateRange = new DateRange(startTime, endTime); + int totalDataNodes = getDataNodes().size(); + for (int i = 0; i < totalDataNodes; i++) { + client().execute(ADBatchAnomalyResultAction.INSTANCE, adBatchAnomalyResultRequest(dateRange)).actionGet(5000); + } + waitUntil(() -> countDocs(ADCommonName.DETECTION_STATE_INDEX) >= totalDataNodes, 10, TimeUnit.SECONDS); + + ADBatchAnomalyResultRequest request = adBatchAnomalyResultRequest(dateRange); + try { + client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(5000); + } catch (Exception e) { + assertTrue( + ExceptionUtil + .getErrorMessage(e) + .contains("All nodes' executing batch tasks exceeds limitation No eligible node to run detector") + ); + } + } + + public void testDisableADPlugin() throws IOException { + try { + updateTransientSettings(ImmutableMap.of(AD_ENABLED, false)); + ADBatchAnomalyResultRequest request = adBatchAnomalyResultRequest(new DateRange(startTime, endTime)); + RuntimeException exception = expectThrowsAnyOf( + ImmutableList.of(NotSerializableExceptionWrapper.class, EndRunException.class), + () -> client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(10000) + ); + assertTrue(exception.getMessage(), exception.getMessage().contains("AD functionality is disabled")); + updateTransientSettings(ImmutableMap.of(AD_ENABLED, false)); + } finally { + // guarantee reset back to default + updateTransientSettings(ImmutableMap.of(AD_ENABLED, true)); + } + } + + public void testMultipleTasks() throws IOException, InterruptedException { + updateTransientSettings(ImmutableMap.of(MAX_BATCH_TASK_PER_NODE.getKey(), 2)); + + DateRange dateRange = new DateRange(startTime, endTime); + for (int i = 0; i < getDataNodes().size(); i++) { + client().execute(ADBatchAnomalyResultAction.INSTANCE, adBatchAnomalyResultRequest(dateRange)); + } + + ADBatchAnomalyResultRequest request = adBatchAnomalyResultRequest( + new DateRange(startTime, startTime.plus(2000, ChronoUnit.MINUTES)) + ); + client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(5000); + Thread.sleep(25000); + GetResponse doc = getDoc(ADCommonName.DETECTION_STATE_INDEX, request.getAdTask().getTaskId()); + 
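// after the sleep above, the batch task doc fetched from the state index is expected to have reached a terminal (finished or failed) state, which the next assertion checks +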
assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(doc.getSourceAsMap().get(ADTask.STATE_FIELD))); + updateTransientSettings(ImmutableMap.of(MAX_BATCH_TASK_PER_NODE.getKey(), 1)); + } + + private ADBatchAnomalyResultRequest adBatchAnomalyResultRequest(DateRange dateRange) throws IOException { + return adBatchAnomalyResultRequest(dateRange, testIndex); + } + + private ADBatchAnomalyResultRequest adBatchAnomalyResultRequest(DateRange dateRange, String indexName) throws IOException { + AnomalyDetector detector = TestHelpers + .randomDetector(ImmutableList.of(maxValueFeature()), indexName, detectionIntervalInMinutes, timeField); + ADTask adTask = randomCreatedADTask(randomAlphaOfLength(5), detector, dateRange); + adTask.setTaskId(createADTask(adTask)); + return new ADBatchAnomalyResultRequest(adTask); + } + + private void testInvalidDetectionDateRange(DateRange dateRange) throws IOException, InterruptedException { + testInvalidDetectionDateRange(dateRange, "There is no data in the detection date range"); + } + + private void testInvalidDetectionDateRange(DateRange dateRange, String error) throws IOException, InterruptedException { + ADBatchAnomalyResultRequest request = adBatchAnomalyResultRequest(dateRange); + client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(5000); + Thread.sleep(5000); + GetResponse doc = getDoc(ADCommonName.DETECTION_STATE_INDEX, request.getAdTask().getTaskId()); + assertEquals(error, doc.getSourceAsMap().get(ADTask.ERROR_FIELD)); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ADCancelTaskNodeRequestTests.java b/src/test/java/org/opensearch/ad/transport/ADCancelTaskNodeRequestTests.java index 546628a86..da3c2a7db 100644 --- a/src/test/java/org/opensearch/ad/transport/ADCancelTaskNodeRequestTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADCancelTaskNodeRequestTests.java @@ -16,7 +16,7 @@ import org.opensearch.ad.ADUnitTestCase; import org.opensearch.ad.mock.transport.MockADCancelTaskNodeRequest_1_0; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamInput; public class ADCancelTaskNodeRequestTests extends ADUnitTestCase { diff --git a/src/test/java/org/opensearch/ad/transport/ADCancelTaskNodeRequestTests.java-e b/src/test/java/org/opensearch/ad/transport/ADCancelTaskNodeRequestTests.java-e new file mode 100644 index 000000000..da3c2a7db --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ADCancelTaskNodeRequestTests.java-e @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.opensearch.ad.ADUnitTestCase; +import org.opensearch.ad.mock.transport.MockADCancelTaskNodeRequest_1_0; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; + +public class ADCancelTaskNodeRequestTests extends ADUnitTestCase { + + public void testParseOldADCancelTaskNodeRequestTest() throws IOException { + String detectorId = randomAlphaOfLength(5); + String userName = randomAlphaOfLength(5); + MockADCancelTaskNodeRequest_1_0 oldRequest = new MockADCancelTaskNodeRequest_1_0(detectorId, userName); + BytesStreamOutput output = new BytesStreamOutput(); + oldRequest.writeTo(output); + StreamInput input = output.bytes().streamInput(); + ADCancelTaskNodeRequest parsedRequest = new ADCancelTaskNodeRequest(input); + assertEquals(detectorId, parsedRequest.getId()); + assertEquals(userName, parsedRequest.getUserName()); + assertNull(parsedRequest.getDetectorTaskId()); + assertNull(parsedRequest.getReason()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ADCancelTaskTests.java b/src/test/java/org/opensearch/ad/transport/ADCancelTaskTests.java index 85d839a1a..fa0f857d6 100644 --- a/src/test/java/org/opensearch/ad/transport/ADCancelTaskTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADCancelTaskTests.java @@ -22,8 +22,8 @@ import org.opensearch.ad.task.ADTaskCancellationState; import org.opensearch.cluster.ClusterName; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import com.google.common.collect.ImmutableList; diff --git a/src/test/java/org/opensearch/ad/transport/ADCancelTaskTests.java-e b/src/test/java/org/opensearch/ad/transport/ADCancelTaskTests.java-e new file mode 100644 index 000000000..fa0f857d6 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ADCancelTaskTests.java-e @@ -0,0 +1,77 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.timeseries.TestHelpers.randomDiscoveryNode; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.ADUnitTestCase; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.task.ADTaskCancellationState; +import org.opensearch.cluster.ClusterName; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.StreamInput; + +import com.google.common.collect.ImmutableList; + +public class ADCancelTaskTests extends ADUnitTestCase { + + public void testADCancelTaskRequest() throws IOException { + ADCancelTaskRequest request = new ADCancelTaskRequest( + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomDiscoveryNode() + ); + + BytesStreamOutput output = new BytesStreamOutput(); + request.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + ADCancelTaskRequest parsedRequest = new ADCancelTaskRequest(input); + assertEquals(request.getId(), parsedRequest.getId()); + assertEquals(request.getUserName(), parsedRequest.getUserName()); + } + + public void testInvalidADCancelTaskRequest() { + ADCancelTaskRequest request = new ADCancelTaskRequest(null, null, null, randomDiscoveryNode()); + ActionRequestValidationException validationException = request.validate(); + assertTrue(validationException.getMessage().contains(ADCommonMessages.AD_ID_MISSING_MSG)); + } + + public void testSerializeResponse() throws IOException { + ADTaskCancellationState state = ADTaskCancellationState.CANCELLED; + ADCancelTaskNodeResponse nodeResponse = new ADCancelTaskNodeResponse(randomDiscoveryNode(), state); + + List nodes = ImmutableList.of(nodeResponse); + ADCancelTaskResponse response = new ADCancelTaskResponse(new ClusterName("test"), nodes, ImmutableList.of()); + + BytesStreamOutput output = new BytesStreamOutput(); + response.writeNodesTo(output, nodes); + StreamInput input = output.bytes().streamInput(); + + List adCancelTaskNodeResponses = response.readNodesFrom(input); + assertEquals(1, adCancelTaskNodeResponses.size()); + assertEquals(state, adCancelTaskNodeResponses.get(0).getState()); + + BytesStreamOutput output2 = new BytesStreamOutput(); + response.writeTo(output2); + StreamInput input2 = output2.bytes().streamInput(); + + ADCancelTaskResponse response2 = new ADCancelTaskResponse(input2); + assertEquals(1, response2.getNodes().size()); + assertEquals(state, response2.getNodes().get(0).getState()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java b/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java index 7d9308bb7..6946953fc 100644 --- a/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java @@ -18,7 +18,7 @@ import org.opensearch.action.index.IndexRequest; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.test.OpenSearchTestCase; public class ADResultBulkResponseTests extends OpenSearchTestCase { diff --git a/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java-e 
b/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java-e new file mode 100644 index 000000000..6946953fc --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java-e @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.opensearch.action.index.IndexRequest; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.test.OpenSearchTestCase; + +public class ADResultBulkResponseTests extends OpenSearchTestCase { + public void testSerialization() throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + List retryRequests = new ArrayList<>(); + retryRequests.add(new IndexRequest("index").id("blah").source(Collections.singletonMap("foo", "bar"))); + ADResultBulkResponse response = new ADResultBulkResponse(retryRequests); + response.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + ADResultBulkResponse readResponse = new ADResultBulkResponse(streamInput); + assertTrue(readResponse.hasFailures()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java index 83e83b25c..432849b82 100644 --- a/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java @@ -36,8 +36,8 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.index.IndexingPressure; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; diff --git a/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java-e new file mode 100644 index 000000000..397effff8 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java-e @@ -0,0 +1,215 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItem; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Locale; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.index.IndexingPressure; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.transport.TransportService; + +public class ADResultBulkTransportActionTests extends AbstractTimeSeriesTest { + private ADResultBulkTransportAction resultBulk; + private TransportService transportService; + private ClusterService clusterService; + private IndexingPressure indexingPressure; + private Client client; + private String detectorId; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyResultTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + Settings settings = Settings + .builder() + .put(IndexingPressure.MAX_INDEXING_BYTES.getKey(), "1KB") + .put(AnomalyDetectorSettings.AD_INDEX_PRESSURE_SOFT_LIMIT.getKey(), 0.8) + .build(); + + // without register these settings, the constructor of ADResultBulkTransportAction cannot invoke update consumer + setupTestNodes(AnomalyDetectorSettings.AD_INDEX_PRESSURE_SOFT_LIMIT, AnomalyDetectorSettings.AD_INDEX_PRESSURE_HARD_LIMIT); + transportService = testNodes[0].transportService; + clusterService = testNodes[0].clusterService; + + ActionFilters actionFilters = mock(ActionFilters.class); + indexingPressure = mock(IndexingPressure.class); + + client = mock(Client.class); + detectorId = randomAlphaOfLength(5); + + resultBulk = new ADResultBulkTransportAction(transportService, actionFilters, indexingPressure, settings, clusterService, client); + } + + @Override + @After + public final void tearDown() throws Exception { + tearDownTestNodes(); + super.tearDown(); + } + + @SuppressWarnings("unchecked") + public void testSendAll() { + when(indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes()).thenReturn(0L); + when(indexingPressure.getCurrentReplicaBytes()).thenReturn(0L); + + ADResultBulkRequest originalRequest = new ADResultBulkRequest(); + originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 0.8d, 0d)); + originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 8d, 0.2d)); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. 
Its content is %s", args.length, Arrays.toString(args)), + args.length == 3 + ); + + assertTrue(args[1] instanceof BulkRequest); + assertTrue(args[2] instanceof ActionListener); + BulkRequest request = (BulkRequest) args[1]; + ActionListener listener = (ActionListener) args[2]; + + assertEquals(2, request.requests().size()); + listener.onResponse(null); + return null; + }).when(client).execute(any(), any(), any()); + + PlainActionFuture future = PlainActionFuture.newFuture(); + resultBulk.doExecute(null, originalRequest, future); + + future.actionGet(); + } + + @SuppressWarnings("unchecked") + public void testSendPartial() { + // the limit is 1024 Bytes + when(indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes()).thenReturn(1000L); + when(indexingPressure.getCurrentReplicaBytes()).thenReturn(24L); + + ADResultBulkRequest originalRequest = new ADResultBulkRequest(); + originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 0.8d, 0d)); + originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 8d, 0.2d)); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + args.length == 3 + ); + + assertTrue(args[1] instanceof BulkRequest); + assertTrue(args[2] instanceof ActionListener); + BulkRequest request = (BulkRequest) args[1]; + ActionListener listener = (ActionListener) args[2]; + + assertEquals(1, request.requests().size()); + listener.onResponse(null); + return null; + }).when(client).execute(any(), any(), any()); + + PlainActionFuture future = PlainActionFuture.newFuture(); + resultBulk.doExecute(null, originalRequest, future); + + future.actionGet(); + } + + @SuppressWarnings("unchecked") + public void testSendRandomPartial() { + // 1024 * 0.9 > 400 + 421 > 1024 * 0.6. 1024 is 1KB, our INDEX_PRESSURE_SOFT_LIMIT + when(indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes()).thenReturn(400L); + when(indexingPressure.getCurrentReplicaBytes()).thenReturn(421L); + + ADResultBulkRequest originalRequest = new ADResultBulkRequest(); + for (int i = 0; i < 1000; i++) { + originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 0.8d, 0d)); + } + + originalRequest.add(TestHelpers.randomResultWriteRequest(detectorId, 8d, 0.2d)); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. 
Its content is %s", args.length, Arrays.toString(args)), + args.length == 3 + ); + + assertTrue(args[1] instanceof BulkRequest); + assertTrue(args[2] instanceof ActionListener); + BulkRequest request = (BulkRequest) args[1]; + ActionListener listener = (ActionListener) args[2]; + + int size = request.requests().size(); + assertTrue(1 < size); + // at least 1 half should be removed + assertTrue(String.format(Locale.ROOT, "size is actually %d", size), size < 500); + listener.onResponse(null); + return null; + }).when(client).execute(any(), any(), any()); + + PlainActionFuture future = PlainActionFuture.newFuture(); + resultBulk.doExecute(null, originalRequest, future); + + future.actionGet(); + } + + public void testSerialzationRequest() throws IOException { + ADResultBulkRequest request = new ADResultBulkRequest(); + request.add(TestHelpers.randomResultWriteRequest(detectorId, 0.8d, 0d)); + request.add(TestHelpers.randomResultWriteRequest(detectorId, 8d, 0.2d)); + BytesStreamOutput output = new BytesStreamOutput(); + request.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + ADResultBulkRequest readRequest = new ADResultBulkRequest(streamInput); + assertThat(2, equalTo(readRequest.numberOfActions())); + } + + public void testValidateRequest() { + ActionRequestValidationException e = new ADResultBulkRequest().validate(); + assertThat(e.validationErrors(), hasItem(ADResultBulkRequest.NO_REQUESTS_ADDED_ERR)); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ADStatsITTests.java b/src/test/java/org/opensearch/ad/transport/ADStatsITTests.java index 57da69965..da8f3dce7 100644 --- a/src/test/java/org/opensearch/ad/transport/ADStatsITTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADStatsITTests.java @@ -15,19 +15,19 @@ import java.util.Collections; import java.util.concurrent.ExecutionException; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; public class ADStatsITTests extends OpenSearchIntegTestCase { @Override protected Collection> nodePlugins() { - return Collections.singletonList(AnomalyDetectorPlugin.class); + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } protected Collection> transportClientPlugins() { - return Collections.singletonList(AnomalyDetectorPlugin.class); + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } public void testNormalADStats() throws ExecutionException, InterruptedException { diff --git a/src/test/java/org/opensearch/ad/transport/ADStatsITTests.java-e b/src/test/java/org/opensearch/ad/transport/ADStatsITTests.java-e new file mode 100644 index 000000000..da8f3dce7 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ADStatsITTests.java-e @@ -0,0 +1,39 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.util.Collection; +import java.util.Collections; +import java.util.concurrent.ExecutionException; + +import org.opensearch.plugins.Plugin; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +public class ADStatsITTests extends OpenSearchIntegTestCase { + + @Override + protected Collection> nodePlugins() { + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); + } + + protected Collection> transportClientPlugins() { + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); + } + + public void testNormalADStats() throws ExecutionException, InterruptedException { + ADStatsRequest adStatsRequest = new ADStatsRequest(new String[0]); + + ADStatsNodesResponse response = client().execute(ADStatsNodesAction.INSTANCE, adStatsRequest).get(); + assertTrue("getting stats failed", !response.hasFailures()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java-e new file mode 100644 index 000000000..95799f911 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java-e @@ -0,0 +1,182 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_MODEL_SIZE_PER_NODE; + +import java.time.Clock; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.stats.ADStat; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.stats.InternalStatNames; +import org.opensearch.ad.stats.suppliers.CounterSupplier; +import org.opensearch.ad.stats.suppliers.IndexStatusSupplier; +import org.opensearch.ad.stats.suppliers.ModelsOnNodeSupplier; +import org.opensearch.ad.stats.suppliers.SettableSupplier; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.IndexUtils; +import org.opensearch.ad.util.Throttler; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.monitor.jvm.JvmService; +import org.opensearch.monitor.jvm.JvmStats; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class ADStatsNodesTransportActionTests extends OpenSearchIntegTestCase { + + private ADStatsNodesTransportAction action; + private ADStats adStats; + private Map> statsMap; + private String clusterStatName1, clusterStatName2; + 
private String nodeStatName1, nodeStatName2; + private ADTaskManager adTaskManager; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + Client client = client(); + Clock clock = mock(Clock.class); + Throttler throttler = new Throttler(clock); + ThreadPool threadPool = mock(ThreadPool.class); + IndexNameExpressionResolver indexNameResolver = mock(IndexNameExpressionResolver.class); + IndexUtils indexUtils = new IndexUtils( + client, + new ClientUtil(Settings.EMPTY, client, throttler, threadPool), + clusterService(), + indexNameResolver + ); + ModelManager modelManager = mock(ModelManager.class); + CacheProvider cacheProvider = mock(CacheProvider.class); + EntityCache cache = mock(EntityCache.class); + when(cacheProvider.get()).thenReturn(cache); + + clusterStatName1 = "clusterStat1"; + clusterStatName2 = "clusterStat2"; + nodeStatName1 = "nodeStat1"; + nodeStatName2 = "nodeStat2"; + + Settings settings = Settings.builder().put(MAX_MODEL_SIZE_PER_NODE.getKey(), 10).build(); + ClusterService clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(MAX_MODEL_SIZE_PER_NODE))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + statsMap = new HashMap>() { + { + put(nodeStatName1, new ADStat<>(false, new CounterSupplier())); + put(nodeStatName2, new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService))); + put(clusterStatName1, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index1"))); + put(clusterStatName2, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index2"))); + put(InternalStatNames.JVM_HEAP_USAGE.getName(), new ADStat<>(true, new SettableSupplier())); + } + }; + + adStats = new ADStats(statsMap); + JvmService jvmService = mock(JvmService.class); + JvmStats jvmStats = mock(JvmStats.class); + JvmStats.Mem mem = mock(JvmStats.Mem.class); + + when(jvmService.stats()).thenReturn(jvmStats); + when(jvmStats.getMem()).thenReturn(mem); + when(mem.getHeapUsedPercent()).thenReturn(randomShort()); + + adTaskManager = mock(ADTaskManager.class); + action = new ADStatsNodesTransportAction( + client().threadPool(), + clusterService(), + mock(TransportService.class), + mock(ActionFilters.class), + adStats, + jvmService, + adTaskManager + ); + } + + @Test + public void testNewNodeRequest() { + String nodeId = "nodeId1"; + ADStatsRequest adStatsRequest = new ADStatsRequest(nodeId); + + ADStatsNodeRequest adStatsNodeRequest1 = new ADStatsNodeRequest(adStatsRequest); + ADStatsNodeRequest adStatsNodeRequest2 = action.newNodeRequest(adStatsRequest); + + assertEquals(adStatsNodeRequest1.getADStatsRequest(), adStatsNodeRequest2.getADStatsRequest()); + } + + @Test + public void testNodeOperation() { + String nodeId = clusterService().localNode().getId(); + ADStatsRequest adStatsRequest = new ADStatsRequest((nodeId)); + adStatsRequest.clear(); + + Set statsToBeRetrieved = new HashSet<>(Arrays.asList(nodeStatName1, nodeStatName2)); + + for (String stat : statsToBeRetrieved) { + adStatsRequest.addStat(stat); + } + + ADStatsNodeResponse response = action.nodeOperation(new ADStatsNodeRequest(adStatsRequest)); + + Map stats = response.getStatsMap(); + + assertEquals(statsToBeRetrieved.size(), stats.size()); + for (String statName : stats.keySet()) { + assertTrue(statsToBeRetrieved.contains(statName)); + } + } + + @Test + public void testNodeOperationWithJvmHeapUsage() { + 
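// same flow as testNodeOperation, but the requested stats also include the internal JVM heap usage stat, which should come back like any other node-level stat +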
String nodeId = clusterService().localNode().getId(); + ADStatsRequest adStatsRequest = new ADStatsRequest((nodeId)); + adStatsRequest.clear(); + + Set statsToBeRetrieved = new HashSet<>(Arrays.asList(nodeStatName1, InternalStatNames.JVM_HEAP_USAGE.getName())); + + for (String stat : statsToBeRetrieved) { + adStatsRequest.addStat(stat); + } + + ADStatsNodeResponse response = action.nodeOperation(new ADStatsNodeRequest(adStatsRequest)); + + Map stats = response.getStatsMap(); + + assertEquals(statsToBeRetrieved.size(), stats.size()); + for (String statName : stats.keySet()) { + assertTrue(statsToBeRetrieved.contains(statName)); + } + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ADStatsTests.java b/src/test/java/org/opensearch/ad/transport/ADStatsTests.java index 317dd8dae..b9595e2e7 100644 --- a/src/test/java/org/opensearch/ad/transport/ADStatsTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADStatsTests.java @@ -40,8 +40,8 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.transport.TransportAddress; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.test.OpenSearchTestCase; diff --git a/src/test/java/org/opensearch/ad/transport/ADStatsTests.java-e b/src/test/java/org/opensearch/ad/transport/ADStatsTests.java-e new file mode 100644 index 000000000..ba6fac1e2 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ADStatsTests.java-e @@ -0,0 +1,266 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; + +import java.io.IOException; +import java.time.Clock; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.stream.Collectors; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.Version; +import org.opensearch.action.FailedNodeException; +import org.opensearch.ad.common.exception.JsonPathNotFoundException; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.stats.StatNames; + +import test.org.opensearch.ad.util.JsonDeserializer; + +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; + +public class ADStatsTests extends OpenSearchTestCase { + String node1, nodeName1, clusterName; + Map clusterStats; + DiscoveryNode discoveryNode1; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + node1 = "node1"; + nodeName1 = "nodename1"; + clusterName = "test-cluster-name"; + discoveryNode1 = new DiscoveryNode( + nodeName1, + node1, + new TransportAddress(TransportAddress.META_ADDRESS, 9300), + emptyMap(), + emptySet(), + Version.CURRENT + ); + clusterStats = new HashMap<>(); + } + + @Test + public void testADStatsNodeRequest() throws IOException { + ADStatsNodeRequest adStatsNodeRequest1 = new ADStatsNodeRequest(); + assertNull("ADStatsNodeRequest default constructor failed", adStatsNodeRequest1.getADStatsRequest()); + + ADStatsRequest adStatsRequest = new ADStatsRequest(new String[0]); + ADStatsNodeRequest adStatsNodeRequest2 = new ADStatsNodeRequest(adStatsRequest); + assertEquals("ADStatsNodeRequest has the wrong ADStatsRequest", adStatsNodeRequest2.getADStatsRequest(), adStatsRequest); + + // Test serialization + BytesStreamOutput output = new BytesStreamOutput(); + adStatsNodeRequest2.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + adStatsNodeRequest1 = new ADStatsNodeRequest(streamInput); + assertEquals( + "readStats failed", + adStatsNodeRequest2.getADStatsRequest().getStatsToBeRetrieved(), + adStatsNodeRequest1.getADStatsRequest().getStatsToBeRetrieved() + ); + } + + @Test + public void testSimpleADStatsNodeResponse() throws IOException, JsonPathNotFoundException { + Map stats = new HashMap() { + { + put("testKey", "testValue"); + } + }; + + // Test serialization + ADStatsNodeResponse adStatsNodeResponse = new ADStatsNodeResponse(discoveryNode1, stats); + BytesStreamOutput output = new BytesStreamOutput(); + adStatsNodeResponse.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + 
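// round-trip the node response through the stream and verify the stats map survives serialization +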
ADStatsNodeResponse readResponse = ADStatsNodeResponse.readStats(streamInput); + assertEquals("readStats failed", readResponse.getStatsMap(), adStatsNodeResponse.getStatsMap()); + + // Test toXContent + XContentBuilder builder = jsonBuilder(); + adStatsNodeResponse.toXContent(builder.startObject(), ToXContent.EMPTY_PARAMS).endObject(); + String json = Strings.toString(builder); + + for (Map.Entry stat : stats.entrySet()) { + assertEquals("toXContent does not work", JsonDeserializer.getTextValue(json, stat.getKey()), stat.getValue()); + } + } + + /** + * Test we can serialize stats with entity + * @throws IOException when writeTo and toXContent have errors. + * @throws JsonPathNotFoundException when json deserialization cannot find a path + */ + @Test + public void testADStatsNodeResponseWithEntity() throws IOException, JsonPathNotFoundException { + TreeMap attributes = new TreeMap<>(); + String name1 = "a"; + String name2 = "b"; + String val1 = "a1"; + String val2 = "a2"; + attributes.put(name1, val1); + attributes.put(name2, val2); + String detectorId = "detectorId"; + Entity entity = Entity.createEntityFromOrderedMap(attributes); + EntityModel entityModel = new EntityModel(entity, null, null); + Clock clock = mock(Clock.class); + when(clock.instant()).thenReturn(Instant.now()); + ModelState state = new ModelState( + entityModel, + entity.getModelId(detectorId).get(), + detectorId, + "entity", + clock, + 0.1f + ); + Map stats = state.getModelStateAsMap(); + + // Test serialization + ADStatsNodeResponse adStatsNodeResponse = new ADStatsNodeResponse(discoveryNode1, stats); + BytesStreamOutput output = new BytesStreamOutput(); + adStatsNodeResponse.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + ADStatsNodeResponse readResponse = ADStatsNodeResponse.readStats(streamInput); + assertEquals("readStats failed", readResponse.getStatsMap(), adStatsNodeResponse.getStatsMap()); + + // Test toXContent + XContentBuilder builder = jsonBuilder(); + adStatsNodeResponse.toXContent(builder.startObject(), ToXContent.EMPTY_PARAMS).endObject(); + String json = Strings.toString(builder); + + for (Map.Entry stat : stats.entrySet()) { + if (stat.getKey().equals(ModelState.LAST_CHECKPOINT_TIME_KEY) || stat.getKey().equals(ModelState.LAST_USED_TIME_KEY)) { + assertEquals("toXContent does not work", JsonDeserializer.getLongValue(json, stat.getKey()), stat.getValue()); + } else if (stat.getKey().equals(CommonName.ENTITY_KEY)) { + JsonArray array = JsonDeserializer.getArrayValue(json, stat.getKey()); + assertEquals(2, array.size()); + for (int i = 0; i < 2; i++) { + JsonElement element = array.get(i); + String entityName = JsonDeserializer.getChildNode(element, Entity.ATTRIBUTE_NAME_FIELD).getAsString(); + String entityValue = JsonDeserializer.getChildNode(element, Entity.ATTRIBUTE_VALUE_FIELD).getAsString(); + + assertTrue(entityName.equals(name1) || entityName.equals(name2)); + if (entityName.equals(name1)) { + assertEquals(val1, entityValue); + } else { + assertEquals(val2, entityValue); + } + } + } else { + assertEquals("toXContent does not work", JsonDeserializer.getTextValue(json, stat.getKey()), stat.getValue()); + } + } + } + + @Test + public void testADStatsRequest() throws IOException { + List allStats = Arrays.stream(StatNames.values()).map(StatNames::getName).collect(Collectors.toList()); + ADStatsRequest adStatsRequest = new ADStatsRequest(new String[0]); + + // Test clear() + adStatsRequest.clear(); + for (String stat : allStats) { + assertTrue("clear() fails", 
!adStatsRequest.getStatsToBeRetrieved().contains(stat)); + } + + // Test all() + adStatsRequest.addAll(new HashSet<>(allStats)); + for (String stat : allStats) { + assertTrue("all() fails", adStatsRequest.getStatsToBeRetrieved().contains(stat)); + } + + // Test add stat + adStatsRequest.clear(); + adStatsRequest.addStat(StatNames.AD_EXECUTE_REQUEST_COUNT.getName()); + assertTrue("addStat fails", adStatsRequest.getStatsToBeRetrieved().contains(StatNames.AD_EXECUTE_REQUEST_COUNT.getName())); + + // Test Serialization + BytesStreamOutput output = new BytesStreamOutput(); + adStatsRequest.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + ADStatsRequest readRequest = new ADStatsRequest(streamInput); + assertEquals("Serialization fails", readRequest.getStatsToBeRetrieved(), adStatsRequest.getStatsToBeRetrieved()); + } + + @Test + public void testADStatsNodesResponse() throws IOException, JsonPathNotFoundException { + Map<String, Object> nodeStats = new HashMap<String, Object>() { + { + put("testNodeKey", "testNodeValue"); + } + }; + + ADStatsNodeResponse adStatsNodeResponse = new ADStatsNodeResponse(discoveryNode1, nodeStats); + List<ADStatsNodeResponse> adStatsNodeResponses = Collections.singletonList(adStatsNodeResponse); + List<FailedNodeException> failures = Collections.emptyList(); + ADStatsNodesResponse adStatsNodesResponse = new ADStatsNodesResponse(new ClusterName(clusterName), adStatsNodeResponses, failures); + + // Test toXContent + XContentBuilder builder = jsonBuilder(); + adStatsNodesResponse.toXContent(builder.startObject(), ToXContent.EMPTY_PARAMS).endObject(); + String json = Strings.toString(builder); + + logger.info("JSON: " + json); + + // nodeStats + String nodesJson = JsonDeserializer.getChildNode(json, "nodes").toString(); + String node1Json = JsonDeserializer.getChildNode(nodesJson, node1).toString(); + + for (Map.Entry<String, Object> stat : nodeStats.entrySet()) { + assertEquals( + "toXContent does not work for node stats", + JsonDeserializer.getTextValue(node1Json, stat.getKey()), + stat.getValue() + ); + } + + // Test Serialization + BytesStreamOutput output = new BytesStreamOutput(); + + adStatsNodesResponse.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + ADStatsNodesResponse readRequest = new ADStatsNodesResponse(streamInput); + + builder = jsonBuilder(); + String readJson = Strings.toString(readRequest.toXContent(builder.startObject(), ToXContent.EMPTY_PARAMS).endObject()); + assertEquals("Serialization fails", readJson, json); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ADTaskProfileResponseTests.java b/src/test/java/org/opensearch/ad/transport/ADTaskProfileResponseTests.java index 1807a6497..2f6d555d1 100644 --- a/src/test/java/org/opensearch/ad/transport/ADTaskProfileResponseTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADTaskProfileResponseTests.java @@ -21,7 +21,7 @@ import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.cluster.ClusterName; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import com.google.common.collect.ImmutableList; diff --git a/src/test/java/org/opensearch/ad/transport/ADTaskProfileResponseTests.java-e b/src/test/java/org/opensearch/ad/transport/ADTaskProfileResponseTests.java-e new file mode 100644 index 000000000..2f6d555d1 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ADTaskProfileResponseTests.java-e @@ -0,0 +1,56 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch 
Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.timeseries.TestHelpers.randomDiscoveryNode; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.Version; +import org.opensearch.ad.ADUnitTestCase; +import org.opensearch.ad.model.ADTaskProfile; +import org.opensearch.cluster.ClusterName; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; + +import com.google.common.collect.ImmutableList; + +public class ADTaskProfileResponseTests extends ADUnitTestCase { + + public void testSerializeResponse() throws IOException { + String taskId = randomAlphaOfLength(5); + ADTaskProfile adTaskProfile = new ADTaskProfile(); + adTaskProfile.setTaskId(taskId); + Version remoteAdVersion = Version.CURRENT; + ADTaskProfileNodeResponse nodeResponse = new ADTaskProfileNodeResponse(randomDiscoveryNode(), adTaskProfile, remoteAdVersion); + + List nodeResponses = ImmutableList.of(nodeResponse); + ADTaskProfileResponse response = new ADTaskProfileResponse(new ClusterName("test"), nodeResponses, ImmutableList.of()); + + BytesStreamOutput output = new BytesStreamOutput(); + response.writeNodesTo(output, nodeResponses); + StreamInput input = output.bytes().streamInput(); + + List adTaskProfileNodeResponses = response.readNodesFrom(input); + assertEquals(1, adTaskProfileNodeResponses.size()); + assertEquals(taskId, adTaskProfileNodeResponses.get(0).getAdTaskProfile().getTaskId()); + + BytesStreamOutput output2 = new BytesStreamOutput(); + response.writeTo(output2); + StreamInput input2 = output2.bytes().streamInput(); + + ADTaskProfileResponse response2 = new ADTaskProfileResponse(input2); + assertEquals(1, response2.getNodes().size()); + assertEquals(taskId, response2.getNodes().get(0).getAdTaskProfile().getTaskId()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ADTaskProfileTests.java b/src/test/java/org/opensearch/ad/transport/ADTaskProfileTests.java index 49e4172ae..d8e13e9ec 100644 --- a/src/test/java/org/opensearch/ad/transport/ADTaskProfileTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADTaskProfileTests.java @@ -21,27 +21,27 @@ import org.junit.Ignore; import org.opensearch.Version; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.UUIDs; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.plugins.Plugin; import org.opensearch.test.InternalSettingsPlugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import com.google.common.collect.ImmutableList; public class ADTaskProfileTests extends 
OpenSearchSingleNodeTestCase { @Override protected Collection<Class<? extends Plugin>> getPlugins() { - return pluginList(InternalSettingsPlugin.class, AnomalyDetectorPlugin.class); + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); } @Override diff --git a/src/test/java/org/opensearch/ad/transport/ADTaskProfileTests.java-e b/src/test/java/org/opensearch/ad/transport/ADTaskProfileTests.java-e new file mode 100644 index 000000000..d8e13e9ec --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ADTaskProfileTests.java-e @@ -0,0 +1,196 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.timeseries.TestHelpers.randomDiscoveryNode; + +import java.io.IOException; + +import java.time.Instant; +import java.util.Collection; +import java.util.List; + +import org.junit.Ignore; +import org.opensearch.Version; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.model.ADTaskProfile; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.UUIDs; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.InternalSettingsPlugin; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +import com.google.common.collect.ImmutableList; + +public class ADTaskProfileTests extends OpenSearchSingleNodeTestCase { + @Override + protected Collection<Class<? extends Plugin>> getPlugins() { + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); + } + + @Override + protected NamedWriteableRegistry writableRegistry() { + return getInstanceFromNode(NamedWriteableRegistry.class); + } + + public void testADTaskProfileRequest() throws IOException { + ADTaskProfileRequest request = new ADTaskProfileRequest(randomAlphaOfLength(5), randomDiscoveryNode()); + + BytesStreamOutput output = new BytesStreamOutput(); + request.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + ADTaskProfileRequest parsedRequest = new ADTaskProfileRequest(input); + assertEquals(request.getId(), parsedRequest.getId()); + } + + public void testInvalidADTaskProfileRequest() { + DiscoveryNode node = new DiscoveryNode(UUIDs.randomBase64UUID(), buildNewFakeTransportAddress(), Version.CURRENT); + ADTaskProfileRequest request = new ADTaskProfileRequest(null, node); + ActionRequestValidationException validationException = request.validate(); + assertTrue(validationException.getMessage().contains(ADCommonMessages.AD_ID_MISSING_MSG)); + } + + public void testADTaskProfileNodeResponse() throws IOException { + ADTaskProfile adTaskProfile = new ADTaskProfile( + randomAlphaOfLength(5), + randomInt(), + randomLong(), + randomBoolean(), + randomInt(), + randomLong(), + randomAlphaOfLength(5) + ); + 
ADTaskProfileNodeResponse response = new ADTaskProfileNodeResponse(randomDiscoveryNode(), adTaskProfile, Version.CURRENT); + testADTaskProfileResponse(response); + } + + public void testADTaskProfileNodeResponseWithNullProfile() throws IOException { + ADTaskProfileNodeResponse response = new ADTaskProfileNodeResponse(randomDiscoveryNode(), null, Version.CURRENT); + testADTaskProfileResponse(response); + } + + public void testADTaskProfileNodeResponseReadMethod() throws IOException { + ADTaskProfile adTaskProfile = new ADTaskProfile( + randomAlphaOfLength(5), + randomInt(), + randomLong(), + randomBoolean(), + randomInt(), + randomLong(), + randomAlphaOfLength(5) + ); + ADTaskProfileNodeResponse response = new ADTaskProfileNodeResponse(randomDiscoveryNode(), adTaskProfile, Version.CURRENT); + testADTaskProfileResponse(response); + } + + public void testADTaskProfileNodeResponseReadMethodWithNullProfile() throws IOException { + ADTaskProfileNodeResponse response = new ADTaskProfileNodeResponse(randomDiscoveryNode(), null, Version.CURRENT); + testADTaskProfileResponse(response); + } + + private void testADTaskProfileResponse(ADTaskProfileNodeResponse response) throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + ADTaskProfileNodeResponse parsedResponse = ADTaskProfileNodeResponse.readNodeResponse(input); + if (response.getAdTaskProfile() != null) { + assertTrue(response.getAdTaskProfile().equals(parsedResponse.getAdTaskProfile())); + } else { + assertNull(parsedResponse.getAdTaskProfile()); + } + } + + public void testADTaskProfileParse() throws IOException { + ADTaskProfile adTaskProfile = new ADTaskProfile( + randomAlphaOfLength(5), + randomInt(), + randomLong(), + randomBoolean(), + randomInt(), + randomLong(), + randomAlphaOfLength(5) + ); + String adTaskProfileString = TestHelpers + .xContentBuilderToString(adTaskProfile.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + ADTaskProfile parsedADTaskProfile = ADTaskProfile.parse(TestHelpers.parser(adTaskProfileString)); + assertEquals(adTaskProfile, parsedADTaskProfile); + assertEquals(parsedADTaskProfile.toString(), adTaskProfile.toString()); + } + + @Ignore + public void testSerializeResponse() throws IOException { + DiscoveryNode node = randomDiscoveryNode(); + ADTaskProfile profile = new ADTaskProfile( + TestHelpers.randomAdTask(), + randomInt(), + randomLong(), + randomBoolean(), + randomInt(), + randomLong(), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomInt(), + randomBoolean(), + randomInt(), + randomInt(), + randomInt(), + ImmutableList.of(randomAlphaOfLength(5)), + Instant.now().toEpochMilli() + ); + ADTaskProfileNodeResponse nodeResponse = new ADTaskProfileNodeResponse(node, profile, Version.CURRENT); + ImmutableList nodes = ImmutableList.of(nodeResponse); + ADTaskProfileResponse response = new ADTaskProfileResponse(new ClusterName("test"), nodes, ImmutableList.of()); + + BytesStreamOutput output = new BytesStreamOutput(); + response.writeNodesTo(output, nodes); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + + List adTaskProfileNodeResponses = response.readNodesFrom(input); + assertEquals(1, adTaskProfileNodeResponses.size()); + ADTaskProfileNodeResponse parsedProfile = adTaskProfileNodeResponses.get(0); + + 
assertEquals(profile.getTaskId(), parsedProfile.getAdTaskProfile().getTaskId()); + } + + public void testADTaskProfileParseFullConstructor() throws IOException { + ADTaskProfile adTaskProfile = new ADTaskProfile( + TestHelpers.randomAdTask(), + randomInt(), + randomLong(), + randomBoolean(), + randomInt(), + randomLong(), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomInt(), + randomBoolean(), + randomInt(), + randomInt(), + randomInt(), + ImmutableList.of(randomAlphaOfLength(5)), + Instant.now().toEpochMilli() + ); + String adTaskProfileString = TestHelpers + .xContentBuilderToString(adTaskProfile.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + ADTaskProfile parsedADTaskProfile = ADTaskProfile.parse(TestHelpers.parser(adTaskProfileString)); + assertEquals(adTaskProfile, parsedADTaskProfile); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ADTaskProfileTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/ADTaskProfileTransportActionTests.java-e new file mode 100644 index 000000000..d57b5e1c5 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ADTaskProfileTransportActionTests.java-e @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; + +import org.junit.Before; +import org.opensearch.ad.HistoricalAnalysisIntegTestCase; +import org.opensearch.common.settings.Settings; + +public class ADTaskProfileTransportActionTests extends HistoricalAnalysisIntegTestCase { + + private Instant startTime; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + startTime = Instant.now().minus(10, ChronoUnit.DAYS); + ingestTestData(testIndex, startTime, detectionIntervalInMinutes, "error", 2000); + createDetectorIndex(); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return Settings + .builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 1) + .put(MAX_BATCH_TASK_PER_NODE.getKey(), 1) + .build(); + } + +} diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java index afbae1eb1..30931af13 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java @@ -32,12 +32,12 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; -import org.opensearch.rest.RestStatus; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.rest.RestStatus; import 
org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.threadpool.ThreadPool; diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java-e b/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java-e new file mode 100644 index 000000000..f7cd836f2 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java-e @@ -0,0 +1,147 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.ExecuteADResultResponseRecorder; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.transport.TransportService; + +public class AnomalyDetectorJobActionTests extends OpenSearchIntegTestCase { + private AnomalyDetectorJobTransportAction action; + private Task task; + private AnomalyDetectorJobRequest request; + private ActionListener response; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + ClusterService clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES))) + ); + + Settings build = Settings.builder().build(); + ThreadContext threadContext = new ThreadContext(build); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alice|odfe,aes|engineering,operations"); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + Client client = mock(Client.class); + org.opensearch.threadpool.ThreadPool mockThreadPool = mock(ThreadPool.class); + when(client.threadPool()).thenReturn(mockThreadPool); + when(mockThreadPool.getThreadContext()).thenReturn(threadContext); + + action = new AnomalyDetectorJobTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + client, + clusterService, + indexSettings(), + mock(ADIndexManagement.class), + xContentRegistry(), + mock(ADTaskManager.class), + mock(ExecuteADResultResponseRecorder.class) + ); + 
task = mock(Task.class); + request = new AnomalyDetectorJobRequest("1234", 4567, 7890, "_start"); + response = new ActionListener() { + @Override + public void onResponse(AnomalyDetectorJobResponse adResponse) { + // Will not be called as there is no detector + Assert.assertTrue(false); + } + + @Override + public void onFailure(Exception e) { + // Will not be called as there is no detector + Assert.assertTrue(true); + } + }; + } + + @Test + public void testStartAdJobTransportAction() { + action.doExecute(task, request, response); + } + + @Test + public void testStopAdJobTransportAction() { + AnomalyDetectorJobRequest stopRequest = new AnomalyDetectorJobRequest("1234", 4567, 7890, "_stop"); + action.doExecute(task, stopRequest, response); + } + + @Test + public void testAdJobAction() { + Assert.assertNotNull(AnomalyDetectorJobAction.INSTANCE.name()); + Assert.assertEquals(AnomalyDetectorJobAction.INSTANCE.name(), AnomalyDetectorJobAction.NAME); + } + + @Test + public void testAdJobRequest() throws IOException { + DateRange detectionDateRange = new DateRange(Instant.MIN, Instant.now()); + request = new AnomalyDetectorJobRequest("1234", detectionDateRange, false, 4567, 7890, "_start"); + + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + StreamInput input = out.bytes().streamInput(); + AnomalyDetectorJobRequest newRequest = new AnomalyDetectorJobRequest(input); + Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID()); + } + + @Test + public void testAdJobRequest_NullDetectionDateRange() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + StreamInput input = out.bytes().streamInput(); + AnomalyDetectorJobRequest newRequest = new AnomalyDetectorJobRequest(input); + Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID()); + } + + @Test + public void testAdJobResponse() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + AnomalyDetectorJobResponse response = new AnomalyDetectorJobResponse("1234", 45, 67, 890, RestStatus.OK); + response.writeTo(out); + StreamInput input = out.bytes().streamInput(); + AnomalyDetectorJobResponse newResponse = new AnomalyDetectorJobResponse(input); + Assert.assertEquals(response.getId(), newResponse.getId()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java-e new file mode 100644 index 000000000..daf86ab7c --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java-e @@ -0,0 +1,521 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonMessages.DETECTOR_IS_RUNNING; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR; +import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM; +import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; +import static org.opensearch.timeseries.TestHelpers.HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS; +import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; +import static org.opensearch.timeseries.util.RestHandlerUtils.PROFILE; +import static org.opensearch.timeseries.util.RestHandlerUtils.START_JOB; +import static org.opensearch.timeseries.util.RestHandlerUtils.STOP_JOB; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import org.junit.After; +import org.junit.Before; +import org.junit.Ignore; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetResponse; +import org.opensearch.ad.HistoricalAnalysisIntegTestCase; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.mock.model.MockSimpleLog; +import org.opensearch.ad.mock.transport.MockAnomalyDetectorJobAction; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskProfile; +import org.opensearch.ad.model.ADTaskState; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.client.Client; +import org.opensearch.common.lucene.uid.Versions; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.stats.StatNames; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2) +public class AnomalyDetectorJobTransportActionTests extends HistoricalAnalysisIntegTestCase { + private Instant startTime; + private Instant endTime; + private String type = "error"; + private int maxOldAdTaskDocsPerDetector = 2; + private DateRange dateRange; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + startTime = Instant.now().minus(10, ChronoUnit.DAYS); + endTime = Instant.now(); + dateRange = new DateRange(startTime, endTime); + ingestTestData(testIndex, startTime, detectionIntervalInMinutes, type, 2000); + createDetectorIndex(); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return Settings + .builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 
1) + .put(MAX_BATCH_TASK_PER_NODE.getKey(), 1) + .put(MAX_OLD_AD_TASK_DOCS_PER_DETECTOR.getKey(), maxOldAdTaskDocsPerDetector) + .build(); + } + + public void testDetectorIndexNotFound() { + deleteDetectorIndex(); + String detectorId = randomAlphaOfLength(5); + AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, dateRange); + IndexNotFoundException exception = expectThrows( + IndexNotFoundException.class, + () -> client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(3000) + ); + assertTrue(exception.getMessage().contains("no such index [.opendistro-anomaly-detectors]")); + } + + public void testDetectorNotFound() { + String detectorId = randomAlphaOfLength(5); + AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, dateRange); + OpenSearchStatusException exception = expectThrows( + OpenSearchStatusException.class, + () -> client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000) + ); + assertTrue(exception.getMessage().contains(FAIL_TO_FIND_CONFIG_MSG)); + } + + public void testValidHistoricalAnalysis() throws IOException, InterruptedException { + ADTask adTask = startHistoricalAnalysis(startTime, endTime); + Thread.sleep(10000); + ADTask finishedTask = getADTask(adTask.getTaskId()); + assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(finishedTask.getState())); + } + + public void testStartHistoricalAnalysisWithUser() throws IOException { + AnomalyDetector detector = TestHelpers + .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField); + String detectorId = createDetector(detector); + AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( + detectorId, + dateRange, + true, + UNASSIGNED_SEQ_NO, + UNASSIGNED_PRIMARY_TERM, + START_JOB + ); + Client nodeClient = getDataNodeClient(); + if (nodeClient != null) { + AnomalyDetectorJobResponse response = nodeClient.execute(MockAnomalyDetectorJobAction.INSTANCE, request).actionGet(100000); + ADTask adTask = getADTask(response.getId()); + assertNotNull(adTask.getStartedBy()); + assertNotNull(adTask.getUser()); + } + } + + public void testStartHistoricalAnalysisForSingleCategoryHCWithUser() throws IOException, InterruptedException { + ingestTestData(testIndex, startTime, detectionIntervalInMinutes, type + "1", DEFAULT_IP, 2000, false); + ingestTestData(testIndex, startTime, detectionIntervalInMinutes, type + "2", DEFAULT_IP, 2000, false); + AnomalyDetector detector = TestHelpers + .randomDetector( + ImmutableList.of(maxValueFeature()), + testIndex, + detectionIntervalInMinutes, + MockSimpleLog.TIME_FIELD, + ImmutableList.of(categoryField) + ); + String detectorId = createDetector(detector); + AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( + detectorId, + dateRange, + true, + UNASSIGNED_SEQ_NO, + UNASSIGNED_PRIMARY_TERM, + START_JOB + ); + Client nodeClient = getDataNodeClient(); + + if (nodeClient != null) { + AnomalyDetectorJobResponse response = nodeClient.execute(MockAnomalyDetectorJobAction.INSTANCE, request).actionGet(100000); + waitUntil(() -> { + try { + ADTask task = getADTask(response.getId()); + return !TestHelpers.HISTORICAL_ANALYSIS_RUNNING_STATS.contains(task.getState()); + } catch (IOException e) { + return false; + } + }, 20, TimeUnit.SECONDS); + ADTask adTask = getADTask(response.getId()); + assertEquals(ADTaskType.HISTORICAL_HC_DETECTOR.toString(), adTask.getTaskType()); + assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(adTask.getState())); + 
assertEquals(categoryField, adTask.getDetector().getCategoryFields().get(0)); + + if (ADTaskState.FINISHED.name().equals(adTask.getState())) { + List adTasks = searchADTasks(detectorId, true, 100); + assertEquals(4, adTasks.size()); + List entityTasks = adTasks + .stream() + .filter(task -> ADTaskType.HISTORICAL_HC_ENTITY.name().equals(task.getTaskType())) + .collect(Collectors.toList()); + assertEquals(3, entityTasks.size()); + } + } + } + + public void testStartHistoricalAnalysisForMultiCategoryHCWithUser() throws IOException, InterruptedException { + ingestTestData(testIndex, startTime, detectionIntervalInMinutes, type + "1", DEFAULT_IP, 2000, false); + ingestTestData(testIndex, startTime, detectionIntervalInMinutes, type + "2", DEFAULT_IP, 2000, false); + ingestTestData(testIndex, startTime, detectionIntervalInMinutes, type + "3", "127.0.0.2", 2000, false); + ingestTestData(testIndex, startTime, detectionIntervalInMinutes, type + "4", "127.0.0.2", 2000, false); + + AnomalyDetector detector = TestHelpers + .randomDetector( + ImmutableList.of(maxValueFeature()), + testIndex, + detectionIntervalInMinutes, + MockSimpleLog.TIME_FIELD, + ImmutableList.of(categoryField, ipField) + ); + String detectorId = createDetector(detector); + AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( + detectorId, + dateRange, + true, + UNASSIGNED_SEQ_NO, + UNASSIGNED_PRIMARY_TERM, + START_JOB + ); + Client nodeClient = getDataNodeClient(); + + if (nodeClient != null) { + AnomalyDetectorJobResponse response = nodeClient.execute(MockAnomalyDetectorJobAction.INSTANCE, request).actionGet(100_000); + String taskId = response.getId(); + + waitUntil(() -> { + try { + ADTask task = getADTask(taskId); + return !TestHelpers.HISTORICAL_ANALYSIS_RUNNING_STATS.contains(task.getState()); + } catch (IOException e) { + return false; + } + }, 90, TimeUnit.SECONDS); + ADTask adTask = getADTask(taskId); + assertEquals(ADTaskType.HISTORICAL_HC_DETECTOR.toString(), adTask.getTaskType()); + // Task may fail if memory circuit breaker triggered + assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(adTask.getState())); + assertEquals(categoryField, adTask.getDetector().getCategoryFields().get(0)); + assertEquals(ipField, adTask.getDetector().getCategoryFields().get(1)); + + if (ADTaskState.FINISHED.name().equals(adTask.getState())) { + List adTasks = searchADTasks(detectorId, taskId, true, 100); + assertEquals(5, adTasks.size()); + List entityTasks = adTasks + .stream() + .filter(task -> ADTaskType.HISTORICAL_HC_ENTITY.name().equals(task.getTaskType())) + .collect(Collectors.toList()); + assertEquals(5, entityTasks.size()); + } + } + } + + public void testRunMultipleTasksForHistoricalAnalysis() throws IOException, InterruptedException { + AnomalyDetector detector = TestHelpers + .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField); + String detectorId = createDetector(detector); + AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, dateRange); + AnomalyDetectorJobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); + assertNotNull(response.getId()); + OpenSearchStatusException exception = null; + // Add retry to solve the flaky test + for (int i = 0; i < 10; i++) { + exception = expectThrows( + OpenSearchStatusException.class, + () -> client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000) + ); + if (exception.getMessage().contains(DETECTOR_IS_RUNNING)) { + break; + } else 
{ + logger.error("Unexpected error happened when rerun detector", exception); + } + Thread.sleep(1000); + } + assertNotNull(exception); + assertTrue(exception.getMessage().contains(DETECTOR_IS_RUNNING)); + assertEquals(DETECTOR_IS_RUNNING, exception.getMessage()); + Thread.sleep(20000); + List adTasks = searchADTasks(detectorId, null, 100); + assertEquals(1, adTasks.size()); + assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(adTasks.get(0).getState())); + } + + public void testRaceConditionByStartingMultipleTasks() throws IOException, InterruptedException { + AnomalyDetector detector = TestHelpers + .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField); + String detectorId = createDetector(detector); + AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( + detectorId, + dateRange, + true, + UNASSIGNED_SEQ_NO, + UNASSIGNED_PRIMARY_TERM, + START_JOB + ); + client().execute(AnomalyDetectorJobAction.INSTANCE, request); + client().execute(AnomalyDetectorJobAction.INSTANCE, request); + + Thread.sleep(5000); + List adTasks = searchADTasks(detectorId, null, 100); + + assertEquals(1, adTasks.size()); + assertTrue(adTasks.get(0).getLatest()); + assertNotEquals(ADTaskState.FAILED.name(), adTasks.get(0).getState()); + } + + // TODO: fix this flaky test case + @Ignore + public void testCleanOldTaskDocs() throws InterruptedException, IOException { + AnomalyDetector detector = TestHelpers + .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField); + String detectorId = createDetector(detector); + + createDetectionStateIndex(); + List states = ImmutableList.of(ADTaskState.FAILED, ADTaskState.FINISHED, ADTaskState.STOPPED); + for (ADTaskState state : states) { + ADTask task = randomADTask(randomAlphaOfLength(5), detector, detectorId, dateRange, state); + createADTask(task); + } + long count = countDocs(ADCommonName.DETECTION_STATE_INDEX); + assertEquals(states.size(), count); + + AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( + detectorId, + dateRange, + true, + randomLong(), + randomLong(), + START_JOB + ); + + AtomicReference response = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + Thread.sleep(2000); + client().execute(AnomalyDetectorJobAction.INSTANCE, request, ActionListener.wrap(r -> { + latch.countDown(); + response.set(r); + }, e -> { latch.countDown(); })); + latch.await(); + Thread.sleep(10000); + count = countDetectorDocs(detectorId); + // we have one latest task, so total count should add 1 + assertEquals(maxOldAdTaskDocsPerDetector + 1, count); + } + + @After + @Override + public void tearDown() throws Exception { + super.tearDown(); + // delete index will clear search context, this can avoid in-flight contexts error + deleteIndexIfExists(CommonName.CONFIG_INDEX); + deleteIndexIfExists(ADCommonName.DETECTION_STATE_INDEX); + } + + public void testStartRealtimeDetector() throws IOException { + List realtimeResult = startRealtimeDetector(); + String detectorId = realtimeResult.get(0); + String jobId = realtimeResult.get(1); + GetResponse jobDoc = getDoc(CommonName.JOB_INDEX, detectorId); + AnomalyDetectorJob job = toADJob(jobDoc); + assertTrue(job.isEnabled()); + assertEquals(detectorId, job.getName()); + + List adTasks = searchADTasks(detectorId, true, 10); + assertEquals(1, adTasks.size()); + assertEquals(ADTaskType.REALTIME_SINGLE_ENTITY.name(), adTasks.get(0).getTaskType()); + assertNotEquals(jobId, 
adTasks.get(0).getTaskId()); + } + + private List startRealtimeDetector() throws IOException { + AnomalyDetector detector = TestHelpers + .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField); + String detectorId = createDetector(detector); + AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, null); + AnomalyDetectorJobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); + String jobId = response.getId(); + assertEquals(detectorId, jobId); + return ImmutableList.of(detectorId, jobId); + } + + public void testRealtimeDetectorWithoutFeature() throws IOException { + AnomalyDetector detector = TestHelpers.randomDetector(ImmutableList.of(), testIndex, detectionIntervalInMinutes, timeField); + testInvalidDetector(detector, "Can't start detector job as no features configured"); + } + + public void testHistoricalDetectorWithoutFeature() throws IOException { + AnomalyDetector detector = TestHelpers.randomDetector(ImmutableList.of(), testIndex, detectionIntervalInMinutes, timeField); + testInvalidDetector(detector, "Can't start detector job as no features configured"); + } + + public void testRealtimeDetectorWithoutEnabledFeature() throws IOException { + AnomalyDetector detector = TestHelpers + .randomDetector(ImmutableList.of(TestHelpers.randomFeature(false)), testIndex, detectionIntervalInMinutes, timeField); + testInvalidDetector(detector, "Can't start detector job as no enabled features configured"); + } + + public void testHistoricalDetectorWithoutEnabledFeature() throws IOException { + AnomalyDetector detector = TestHelpers + .randomDetector(ImmutableList.of(TestHelpers.randomFeature(false)), testIndex, detectionIntervalInMinutes, timeField); + testInvalidDetector(detector, "Can't start detector job as no enabled features configured"); + } + + private void testInvalidDetector(AnomalyDetector detector, String error) throws IOException { + String detectorId = createDetector(detector); + AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, dateRange); + OpenSearchStatusException exception = expectThrows( + OpenSearchStatusException.class, + () -> client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000) + ); + assertEquals(error, exception.getMessage()); + } + + private AnomalyDetectorJobRequest startDetectorJobRequest(String detectorId, DateRange dateRange) { + return new AnomalyDetectorJobRequest(detectorId, dateRange, false, UNASSIGNED_SEQ_NO, UNASSIGNED_PRIMARY_TERM, START_JOB); + } + + private AnomalyDetectorJobRequest stopDetectorJobRequest(String detectorId, boolean historical) { + return new AnomalyDetectorJobRequest(detectorId, null, historical, UNASSIGNED_SEQ_NO, UNASSIGNED_PRIMARY_TERM, STOP_JOB); + } + + public void testStopRealtimeDetector() throws IOException { + List realtimeResult = startRealtimeDetector(); + String detectorId = realtimeResult.get(0); + String jobId = realtimeResult.get(1); + + AnomalyDetectorJobRequest request = stopDetectorJobRequest(detectorId, false); + client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); + GetResponse doc = getDoc(CommonName.JOB_INDEX, detectorId); + AnomalyDetectorJob job = toADJob(doc); + assertFalse(job.isEnabled()); + assertEquals(detectorId, job.getName()); + + List adTasks = searchADTasks(detectorId, true, 10); + assertEquals(1, adTasks.size()); + assertEquals(ADTaskType.REALTIME_SINGLE_ENTITY.name(), adTasks.get(0).getTaskType()); + assertNotEquals(jobId, 
adTasks.get(0).getTaskId()); + assertEquals(ADTaskState.STOPPED.name(), adTasks.get(0).getState()); + } + + public void testStopHistoricalDetector() throws IOException, InterruptedException { + updateTransientSettings(ImmutableMap.of(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 5)); + ADTask adTask = startHistoricalAnalysis(startTime, endTime); + assertEquals(ADTaskState.INIT.name(), adTask.getState()); + assertNull(adTask.getStartedBy()); + assertNull(adTask.getUser()); + waitUntil(() -> { + try { + ADTask task = getADTask(adTask.getTaskId()); + boolean taskRunning = TestHelpers.HISTORICAL_ANALYSIS_RUNNING_STATS.contains(task.getState()); + if (taskRunning) { + // It's possible that the task not started on worker node yet. Recancel it to make sure + // task cancelled. + AnomalyDetectorJobRequest request = stopDetectorJobRequest(adTask.getId(), true); + client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); + } + return !taskRunning; + } catch (Exception e) { + return false; + } + }, 20, TimeUnit.SECONDS); + ADTask stoppedTask = getADTask(adTask.getTaskId()); + assertEquals(ADTaskState.STOPPED.name(), stoppedTask.getState()); + assertEquals(0, getExecutingADTask()); + } + + public void testProfileHistoricalDetector() throws IOException, InterruptedException { + ADTask adTask = startHistoricalAnalysis(startTime, endTime); + GetAnomalyDetectorRequest request = taskProfileRequest(adTask.getId()); + GetAnomalyDetectorResponse response = client().execute(GetAnomalyDetectorAction.INSTANCE, request).actionGet(10000); + assertTrue(response.getDetectorProfile().getAdTaskProfile() != null); + + ADTask finishedTask = getADTask(adTask.getTaskId()); + int i = 0; + while (TestHelpers.HISTORICAL_ANALYSIS_RUNNING_STATS.contains(finishedTask.getState()) && i < 10) { + finishedTask = getADTask(adTask.getTaskId()); + Thread.sleep(2000); + i++; + } + assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(finishedTask.getState())); + + response = client().execute(GetAnomalyDetectorAction.INSTANCE, request).actionGet(10000); + assertNull(response.getDetectorProfile().getAdTaskProfile().getNodeId()); + ADTask profileAdTask = response.getDetectorProfile().getAdTaskProfile().getAdTask(); + assertEquals(finishedTask.getTaskId(), profileAdTask.getTaskId()); + assertEquals(finishedTask.getId(), profileAdTask.getId()); + assertEquals(finishedTask.getDetector(), profileAdTask.getDetector()); + assertEquals(finishedTask.getState(), profileAdTask.getState()); + } + + public void testProfileWithMultipleRunningTask() throws IOException { + ADTask adTask1 = startHistoricalAnalysis(startTime, endTime); + ADTask adTask2 = startHistoricalAnalysis(startTime, endTime); + + GetAnomalyDetectorRequest request1 = taskProfileRequest(adTask1.getId()); + GetAnomalyDetectorRequest request2 = taskProfileRequest(adTask2.getId()); + GetAnomalyDetectorResponse response1 = client().execute(GetAnomalyDetectorAction.INSTANCE, request1).actionGet(10000); + GetAnomalyDetectorResponse response2 = client().execute(GetAnomalyDetectorAction.INSTANCE, request2).actionGet(10000); + ADTaskProfile taskProfile1 = response1.getDetectorProfile().getAdTaskProfile(); + ADTaskProfile taskProfile2 = response2.getDetectorProfile().getAdTaskProfile(); + assertNotNull(taskProfile1.getNodeId()); + assertNotNull(taskProfile2.getNodeId()); + assertNotEquals(taskProfile1.getNodeId(), taskProfile2.getNodeId()); + } + + private GetAnomalyDetectorRequest taskProfileRequest(String detectorId) throws IOException { + return new 
GetAnomalyDetectorRequest(detectorId, Versions.MATCH_ANY, false, false, "", PROFILE, true, null); + } + + private long getExecutingADTask() { + ADStatsRequest adStatsRequest = new ADStatsRequest(getDataNodesArray()); + Set validStats = ImmutableSet.of(StatNames.AD_EXECUTING_BATCH_TASK_COUNT.getName()); + adStatsRequest.addAll(validStats); + StatsAnomalyDetectorResponse statsResponse = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5000); + AtomicLong totalExecutingTask = new AtomicLong(0); + statsResponse + .getAdStatsResponse() + .getADStatsNodesResponse() + .getNodes() + .forEach( + node -> { totalExecutingTask.getAndAdd((Long) node.getStatsMap().get(StatNames.AD_EXECUTING_BATCH_TASK_COUNT.getName())); } + ); + return totalExecutingTask.get(); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java index 38fdbddc0..6b671d6e2 100644 --- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java @@ -94,17 +94,17 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NotSerializableExceptionWrapper; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.index.Index; import org.opensearch.index.IndexNotFoundException; -import org.opensearch.index.shard.ShardId; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java-e b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java-e new file mode 100644 index 000000000..5d0a25a1a --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java-e @@ -0,0 +1,1871 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.hamcrest.CoreMatchers.startsWith; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.anyDouble; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.same; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.opensearch.timeseries.TestHelpers.createIndexBlockedState; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.function.Function; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.mockito.ArgumentCaptor; +import org.opensearch.OpenSearchTimeoutException; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.common.exception.JsonPathNotFoundException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.SinglePointFeatures; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.SingleStreamModelIdMapper; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.DetectorInternalState; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.stats.ADStat; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.stats.suppliers.CounterSupplier; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlocks; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; +import 
org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.index.Index; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.transport.NodeNotConnectedException; +import org.opensearch.transport.RemoteTransportException; +import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportInterceptor; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.TransportResponse; +import org.opensearch.transport.TransportResponseHandler; +import org.opensearch.transport.TransportService; + +import test.org.opensearch.ad.util.JsonDeserializer; + +import com.google.gson.JsonElement; + +public class AnomalyResultTests extends AbstractTimeSeriesTest { + private Settings settings; + private TransportService transportService; + private ClusterService clusterService; + private NodeStateManager stateManager; + private FeatureManager featureQuery; + private ModelManager normalModelManager; + private Client client; + private SecurityClientUtil clientUtil; + private AnomalyDetector detector; + private HashRing hashRing; + private IndexNameExpressionResolver indexNameResolver; + private String thresholdModelID; + private String adID; + private String featureId; + private String featureName; + private ADCircuitBreakerService adCircuitBreakerService; + private ADStats adStats; + private double confidence; + private double anomalyGrade; + private ADTaskManager adTaskManager; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyResultTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @SuppressWarnings("unchecked") + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + super.setUpLog4jForJUnit(AnomalyResultTransportAction.class); + + setupTestNodes(AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, AnomalyDetectorSettings.PAGE_SIZE); + + transportService = testNodes[0].transportService; + clusterService = testNodes[0].clusterService; + settings = 
clusterService.getSettings(); + + stateManager = mock(NodeStateManager.class); + when(stateManager.isMuted(any(String.class), any(String.class))).thenReturn(false); + when(stateManager.markColdStartRunning(anyString())).thenReturn(() -> {}); + + detector = mock(AnomalyDetector.class); + featureId = "xyz"; + // we have one feature + when(detector.getEnabledFeatureIds()).thenReturn(Collections.singletonList(featureId)); + featureName = "abc"; + when(detector.getEnabledFeatureNames()).thenReturn(Collections.singletonList(featureName)); + List userIndex = new ArrayList<>(); + userIndex.add("test*"); + when(detector.getIndices()).thenReturn(userIndex); + adID = "123"; + when(detector.getId()).thenReturn(adID); + when(detector.getCategoryFields()).thenReturn(null); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + when(detector.getIntervalInMinutes()).thenReturn(1L); + + hashRing = mock(HashRing.class); + Optional localNode = Optional.of(clusterService.state().nodes().getLocalNode()); + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(localNode); + doReturn(localNode).when(hashRing).getNodeByAddress(any()); + featureQuery = mock(FeatureManager.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new SinglePointFeatures(Optional.of(new double[] { 0.0d }), Optional.of(new double[] { 0 }))); + return null; + }).when(featureQuery).getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); + + double rcfScore = 0.2; + confidence = 0.91; + anomalyGrade = 0.5; + normalModelManager = mock(ModelManager.class); + long totalUpdates = 1440; + int relativeIndex = 0; + double[] currentTimeAttribution = new double[] { 0.5, 0.5 }; + double[] pastValues = new double[] { 123, 456 }; + double[][] expectedValuesList = new double[][] { new double[] { 789, 12 } }; + double[] likelihood = new double[] { 1 }; + double threshold = 1.1d; + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener + .onResponse( + new ThresholdingResult( + anomalyGrade, + confidence, + rcfScore, + totalUpdates, + relativeIndex, + currentTimeAttribution, + pastValues, + expectedValuesList, + likelihood, + threshold, + 30 + ) + ); + return null; + }).when(normalModelManager).getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new ThresholdingResult(0, 1.0d, rcfScore)); + return null; + }).when(normalModelManager).getThresholdingResult(any(String.class), any(String.class), anyDouble(), any(ActionListener.class)); + + thresholdModelID = SingleStreamModelIdMapper.getThresholdModelId(adID); // "123-threshold"; + // when(normalModelPartitioner.getThresholdModelId(any(String.class))).thenReturn(thresholdModelID); + adCircuitBreakerService = mock(ADCircuitBreakerService.class); + when(adCircuitBreakerService.isOpen()).thenReturn(false); + + ThreadPool threadPool = mock(ThreadPool.class); + client = mock(Client.class); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + when(client.threadPool()).thenReturn(threadPool); + when(client.threadPool().getThreadContext()).thenReturn(threadContext); + doAnswer(invocation -> { + 
Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + args.length >= 2 + ); + + IndexRequest request = null; + ActionListener listener = null; + if (args[0] instanceof IndexRequest) { + request = (IndexRequest) args[0]; + } + if (args[1] instanceof ActionListener) { + listener = (ActionListener) args[1]; + } + + assertTrue(request != null && listener != null); + ShardId shardId = new ShardId(new Index(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, randomAlphaOfLength(10)), 0); + listener.onResponse(new IndexResponse(shardId, request.id(), 1, 1, 1, true)); + + return null; + }).when(client).index(any(), any()); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, settings); + + indexNameResolver = new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY)); + + Map> statsMap = new HashMap>() { + { + put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + } + }; + + adStats = new ADStats(statsMap); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + GetRequest request = (GetRequest) args[0]; + ActionListener listener = (ActionListener) args[1]; + + if (request.index().equals(ADCommonName.DETECTION_STATE_INDEX)) { + + DetectorInternalState.Builder result = new DetectorInternalState.Builder().lastUpdateTime(Instant.now()); + + listener.onResponse(TestHelpers.createGetResponse(result.build(), detector.getId(), ADCommonName.DETECTION_STATE_INDEX)); + + } + + return null; + }).when(client).get(any(), any()); + + adTaskManager = mock(ADTaskManager.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }) + .when(adTaskManager) + .initRealtimeTaskCacheAndCleanupStaleCache( + anyString(), + any(AnomalyDetector.class), + any(TransportService.class), + any(ActionListener.class) + ); + } + + @Override + @After + public final void tearDown() throws Exception { + tearDownTestNodes(); + client = null; + super.tearDownLog4jForJUnit(); + super.tearDown(); + } + + private Throwable assertException(PlainActionFuture listener, Class exceptionType) { + return expectThrows(exceptionType, () -> listener.actionGet()); + } + + public void testNormal() throws IOException { + + // These constructors register handler in transport service + new RCFResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + normalModelManager, + adCircuitBreakerService, + hashRing, + adStats + ); + new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture 
listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertAnomalyResultResponse(response, anomalyGrade, confidence, 0d); + } + + private void assertAnomalyResultResponse(AnomalyResultResponse response, double anomalyGrade, double confidence, double featureData) { + assertEquals(anomalyGrade, response.getAnomalyGrade(), 0.001); + assertEquals(confidence, response.getConfidence(), 0.001); + assertEquals(1, response.getFeatures().size()); + FeatureData responseFeature = response.getFeatures().get(0); + assertEquals(featureData, responseFeature.getData(), 0.001); + assertEquals(featureId, responseFeature.getFeatureId()); + assertEquals(featureName, responseFeature.getFeatureName()); + } + + /** + * Create handler that would return a failure + * @param handler callback handler + * @return handler that would return a failure + */ + private TransportResponseHandler rcfFailureHandler( + TransportResponseHandler handler, + Exception exception + ) { + return new TransportResponseHandler() { + @Override + public T read(StreamInput in) throws IOException { + return handler.read(in); + } + + @Override + public void handleResponse(T response) { + handler.handleException(new RemoteTransportException("test", exception)); + } + + @Override + public void handleException(TransportException exp) { + handler.handleException(exp); + } + + @Override + public String executor() { + return handler.executor(); + } + }; + } + + public void noModelExceptionTemplate( + Exception thrownException, + String adID, + Class expectedExceptionType, + String error + ) { + + TransportInterceptor failureTransportInterceptor = new TransportInterceptor() { + @Override + public AsyncSender interceptSender(AsyncSender sender) { + return new AsyncSender() { + @Override + public void sendRequest( + Transport.Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + if (RCFResultAction.NAME.equals(action)) { + sender.sendRequest(connection, action, request, options, rcfFailureHandler(handler, thrownException)); + } else { + sender.sendRequest(connection, action, request, options, handler); + } + } + }; + } + }; + + // need to close nodes created in the setUp nodes and create new nodes + // for the failure interceptor. Otherwise, we will get thread leak error. + tearDownTestNodes(); + setupTestNodes( + failureTransportInterceptor, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, + AnomalyDetectorSettings.PAGE_SIZE + ); + + // mock hashing ring response. 
This has to happen after setting up test nodes with the failure interceptor + Optional discoveryNode = Optional.of(testNodes[1].discoveryNode()); + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(discoveryNode); + when(hashRing.getNodeByAddress(any(TransportAddress.class))).thenReturn(discoveryNode); + // register handler on testNodes[1] + new RCFResultTransportAction( + new ActionFilters(Collections.emptySet()), + testNodes[1].transportService, + normalModelManager, + adCircuitBreakerService, + hashRing, + adStats + ); + + TransportService realTransportService = testNodes[0].transportService; + ClusterService realClusterService = testNodes[0].clusterService; + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + realTransportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + realClusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + Throwable exception = assertException(listener, expectedExceptionType); + assertTrue("actual message: " + exception.getMessage(), exception.getMessage().contains(error)); + } + + public void noModelExceptionTemplate(Exception exception, String adID, String error) { + noModelExceptionTemplate(exception, adID, exception.getClass(), error); + } + + @SuppressWarnings("unchecked") + public void testInsufficientCapacityExceptionDuringColdStart() { + + ModelManager rcfManager = mock(ModelManager.class); + doThrow(ResourceNotFoundException.class) + .when(rcfManager) + .getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); + + when(stateManager.fetchExceptionAndClear(any(String.class))) + .thenReturn(Optional.of(new LimitExceededException(adID, CommonMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG))); + + // These constructors register handler in transport service + new RCFResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + rcfManager, + adCircuitBreakerService, + hashRing, + adStats + ); + new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + assertException(listener, LimitExceededException.class); + } + + @SuppressWarnings("unchecked") + public void testInsufficientCapacityExceptionDuringRestoringModel() { + + ModelManager rcfManager = mock(ModelManager.class); + doThrow(new NotSerializableExceptionWrapper(new LimitExceededException(adID, CommonMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG))) + .when(rcfManager) + .getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); + + // These constructors 
register handler in transport service + new RCFResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + rcfManager, + adCircuitBreakerService, + hashRing, + adStats + ); + new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + assertException(listener, LimitExceededException.class); + } + + private TransportResponseHandler rcfResponseHandler(TransportResponseHandler handler) { + return new TransportResponseHandler() { + @Override + public T read(StreamInput in) throws IOException { + return handler.read(in); + } + + @Override + @SuppressWarnings("unchecked") + public void handleResponse(T response) { + handler + .handleResponse( + (T) new RCFResultResponse( + 1, + 1, + 100, + new double[0], + randomInt(), + randomDouble(), + Version.CURRENT, + randomIntBetween(-3, 0), + new double[] { randomDouble(), randomDouble() }, + new double[][] { new double[] { randomDouble(), randomDouble() } }, + new double[] { randomDouble() }, + randomDoubleBetween(1.1, 10.0, true) + ) + ); + } + + @Override + public void handleException(TransportException exp) { + handler.handleException(exp); + } + + @Override + public String executor() { + return handler.executor(); + } + }; + } + + public void thresholdExceptionTestTemplate( + Exception thrownException, + String adID, + Class expectedExceptionType, + String error + ) { + + TransportInterceptor failureTransportInterceptor = new TransportInterceptor() { + @Override + public AsyncSender interceptSender(AsyncSender sender) { + return new AsyncSender() { + @Override + public void sendRequest( + Transport.Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + if (ThresholdResultAction.NAME.equals(action)) { + sender.sendRequest(connection, action, request, options, rcfFailureHandler(handler, thrownException)); + } else if (RCFResultAction.NAME.equals(action)) { + sender.sendRequest(connection, action, request, options, rcfResponseHandler(handler)); + } else { + sender.sendRequest(connection, action, request, options, handler); + } + } + }; + } + }; + + // need to close nodes created in the setUp nodes and create new nodes + // for the failure interceptor. Otherwise, we will get thread leak error. + tearDownTestNodes(); + setupTestNodes( + failureTransportInterceptor, + Settings.EMPTY, + AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, + AnomalyDetectorSettings.PAGE_SIZE + ); + + // mock hashing ring response. 
This has to happen after setting up test nodes with the failure interceptor + Optional discoveryNode = Optional.of(testNodes[1].discoveryNode()); + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(discoveryNode); + when(hashRing.getNodeByAddress(any(TransportAddress.class))).thenReturn(discoveryNode); + // register handlers on testNodes[1] + ActionFilters actionFilters = new ActionFilters(Collections.emptySet()); + new RCFResultTransportAction( + actionFilters, + testNodes[1].transportService, + normalModelManager, + adCircuitBreakerService, + hashRing, + adStats + ); + new ThresholdResultTransportAction(actionFilters, testNodes[1].transportService, normalModelManager); + + TransportService realTransportService = testNodes[0].transportService; + ClusterService realClusterService = testNodes[0].clusterService; + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + realTransportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + realClusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + Throwable exception = assertException(listener, expectedExceptionType); + assertTrue("actual message: " + exception.getMessage(), exception.getMessage().contains(error)); + } + + public void testCircuitBreaker() { + + ADCircuitBreakerService breakerService = mock(ADCircuitBreakerService.class); + when(breakerService.isOpen()).thenReturn(true); + + // These constructors register handler in transport service + new RCFResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + normalModelManager, + breakerService, + hashRing, + adStats + ); + new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + breakerService, + adStats, + threadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + assertException(listener, LimitExceededException.class); + } + + /** + * Test whether we can handle NodeNotConnectedException when sending requests to + * remote nodes. + * + * @param isRCF whether RCF model node throws node connection + * exception or not + * @param temporary whether node has only temporary connection issue. If + * yes, we should not trigger hash ring rebuilding. 
+ * @param numberOfBuildCall the number of expected hash ring build call + */ + private void nodeNotConnectedExceptionTemplate(boolean isRCF, boolean temporary, int numberOfBuildCall) { + ClusterService hackedClusterService = spy(clusterService); + + TransportService exceptionTransportService = spy(transportService); + + DiscoveryNode rcfNode = clusterService.state().nodes().getLocalNode(); + DiscoveryNode thresholdNode = testNodes[1].discoveryNode(); + + if (isRCF) { + doThrow(new NodeNotConnectedException(rcfNode, "rcf node not connected")) + .when(exceptionTransportService) + .getConnection(same(rcfNode)); + } else { + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(eq(thresholdModelID))).thenReturn(Optional.of(thresholdNode)); + when(hashRing.getNodeByAddress(any())).thenReturn(Optional.of(thresholdNode)); + doThrow(new NodeNotConnectedException(rcfNode, "rcf node not connected")) + .when(exceptionTransportService) + .getConnection(same(thresholdNode)); + } + + if (!temporary) { + when(hackedClusterService.state()).thenReturn(ClusterState.builder(new ClusterName("test")).build()); + } + + // These constructors register handler in transport service + new RCFResultTransportAction( + new ActionFilters(Collections.emptySet()), + exceptionTransportService, + normalModelManager, + adCircuitBreakerService, + hashRing, + adStats + ); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + exceptionTransportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + hackedClusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + assertException(listener, TimeSeriesException.class); + + if (!temporary) { + verify(hashRing, times(numberOfBuildCall)).buildCirclesForRealtimeAD(); + verify(stateManager, never()).addPressure(any(String.class), any(String.class)); + } else { + verify(hashRing, never()).buildCirclesForRealtimeAD(); + verify(stateManager, times(numberOfBuildCall)).addPressure(any(String.class), any(String.class)); + } + } + + public void testRCFNodeNotConnectedException() { + // we expect one hashRing.build calls since we have one RCF model partitions + nodeNotConnectedExceptionTemplate(true, false, 1); + } + + public void testTemporaryRCFNodeNotConnectedException() { + // we expect one hashRing.build calls since we have one RCF model partitions + nodeNotConnectedExceptionTemplate(true, true, 1); + } + + @SuppressWarnings("unchecked") + public void testMute() { + NodeStateManager muteStateManager = mock(NodeStateManager.class); + when(muteStateManager.isMuted(any(String.class), any(String.class))).thenReturn(true); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(muteStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + muteStateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + 
NamedXContentRegistry.EMPTY, + adTaskManager + ); + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + Throwable exception = assertException(listener, TimeSeriesException.class); + assertThat(exception.getMessage(), containsString(AnomalyResultTransportAction.NODE_UNRESPONSIVE_ERR_MSG)); + } + + public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOException { + // These constructors register handler in transport service + new RCFResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + normalModelManager, + adCircuitBreakerService, + hashRing, + adStats + ); + Optional localNode = Optional.of(clusterService.state().nodes().getLocalNode()); + + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(localNode); + doReturn(localNode).when(hashRing).getNodeByAddress(any()); + new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); + + new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + TransportRequestOptions option = TransportRequestOptions + .builder() + .withType(TransportRequestOptions.Type.STATE) + .withTimeout(6000) + .build(); + + transportService + .sendRequest( + clusterService.state().nodes().getLocalNode(), + AnomalyResultAction.NAME, + new AnomalyResultRequest(adID, 100, 200), + option, + new TransportResponseHandler() { + + @Override + public AnomalyResultResponse read(StreamInput in) throws IOException { + return new AnomalyResultResponse(in); + } + + @Override + public void handleResponse(AnomalyResultResponse response) { + assertAnomalyResultResponse(response, 0, 1, 0d); + } + + @Override + public void handleException(TransportException exp) { + assertThat(exp, is(nullValue())); + } + + @Override + public String executor() { + return ThreadPool.Names.GENERIC; + } + } + ); + } + + public void testSerialzationResponse() throws IOException { + AnomalyResultResponse response = new AnomalyResultResponse( + 4d, + 0.993, + 1.01, + Collections.singletonList(new FeatureData(featureId, featureName, 0d)), + randomAlphaOfLength(4), + randomLong(), + randomLong(), + randomBoolean(), + randomInt(), + new double[] { randomDoubleBetween(0, 1.0, true), randomDoubleBetween(0, 1.0, true) }, + new double[] { randomDouble(), randomDouble() }, + new double[][] { new double[] { randomDouble(), randomDouble() } }, + new double[] { randomDouble() }, + randomDoubleBetween(1.1, 10.0, true) + ); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + AnomalyResultResponse readResponse = AnomalyResultAction.INSTANCE.getResponseReader().read(streamInput); + assertAnomalyResultResponse(readResponse, readResponse.getAnomalyGrade(), readResponse.getConfidence(), 0d); + } + + public void testJsonResponse() throws IOException, JsonPathNotFoundException { + AnomalyResultResponse response = new AnomalyResultResponse( + 4d, + 0.993, + 1.01, + Collections.singletonList(new FeatureData(featureId, featureName, 0d)), + randomAlphaOfLength(4), + randomLong(), + 
randomLong(), + randomBoolean(), + randomInt(), + new double[] { randomDoubleBetween(0, 1.0, true), randomDoubleBetween(0, 1.0, true) }, + new double[] { randomDouble(), randomDouble() }, + new double[][] { new double[] { randomDouble(), randomDouble() } }, + new double[] { randomDouble() }, + randomDoubleBetween(1.1, 10.0, true) + ); + XContentBuilder builder = jsonBuilder(); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String json = Strings.toString(builder); + Function function = (s) -> { + try { + String featureId = JsonDeserializer.getTextValue(s, FeatureData.FEATURE_ID_FIELD); + String featureName = JsonDeserializer.getTextValue(s, FeatureData.FEATURE_NAME_FIELD); + double featureValue = JsonDeserializer.getDoubleValue(s, FeatureData.DATA_FIELD); + return new FeatureData(featureId, featureName, featureValue); + } catch (Exception e) { + Assert.fail(e.getMessage()); + } + return null; + }; + + AnomalyResultResponse readResponse = new AnomalyResultResponse( + JsonDeserializer.getDoubleValue(json, AnomalyResultResponse.ANOMALY_GRADE_JSON_KEY), + JsonDeserializer.getDoubleValue(json, AnomalyResultResponse.CONFIDENCE_JSON_KEY), + JsonDeserializer.getDoubleValue(json, AnomalyResultResponse.ANOMALY_SCORE_JSON_KEY), + JsonDeserializer.getListValue(json, function, AnomalyResultResponse.FEATURES_JSON_KEY), + randomAlphaOfLength(4), + randomLong(), + randomLong(), + randomBoolean(), + randomInt(), + new double[] { randomDoubleBetween(0, 1.0, true), randomDoubleBetween(0, 1.0, true) }, + new double[] { randomDouble(), randomDouble() }, + new double[][] { new double[] { randomDouble(), randomDouble() } }, + new double[] { randomDouble() }, + randomDoubleBetween(1.1, 10.0, true) + ); + assertAnomalyResultResponse(readResponse, readResponse.getAnomalyGrade(), readResponse.getConfidence(), 0d); + } + + public void testSerialzationRequest() throws IOException { + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + BytesStreamOutput output = new BytesStreamOutput(); + request.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + AnomalyResultRequest readRequest = new AnomalyResultRequest(streamInput); + assertThat(request.getAdID(), equalTo(readRequest.getAdID())); + assertThat(request.getStart(), equalTo(readRequest.getStart())); + assertThat(request.getEnd(), equalTo(readRequest.getEnd())); + } + + public void testJsonRequest() throws IOException, JsonPathNotFoundException { + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + XContentBuilder builder = jsonBuilder(); + request.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String json = Strings.toString(builder); + assertEquals(JsonDeserializer.getTextValue(json, ADCommonName.ID_JSON_KEY), request.getAdID()); + assertEquals(JsonDeserializer.getLongValue(json, CommonName.START_JSON_KEY), request.getStart()); + assertEquals(JsonDeserializer.getLongValue(json, CommonName.END_JSON_KEY), request.getEnd()); + } + + public void testEmptyID() { + ActionRequestValidationException e = new AnomalyResultRequest("", 100, 200).validate(); + assertThat(e.validationErrors(), hasItem(ADCommonMessages.AD_ID_MISSING_MSG)); + } + + public void testZeroStartTime() { + ActionRequestValidationException e = new AnomalyResultRequest(adID, 0, 200).validate(); + assertThat(e.validationErrors(), hasItem(startsWith(CommonMessages.INVALID_TIMESTAMP_ERR_MSG))); + } + + public void testNegativeEndTime() { + ActionRequestValidationException e = new AnomalyResultRequest(adID, 0, 
-200).validate(); + assertThat(e.validationErrors(), hasItem(startsWith(CommonMessages.INVALID_TIMESTAMP_ERR_MSG))); + } + + public void testNegativeTime() { + ActionRequestValidationException e = new AnomalyResultRequest(adID, 10, -200).validate(); + assertThat(e.validationErrors(), hasItem(startsWith(CommonMessages.INVALID_TIMESTAMP_ERR_MSG))); + } + + // no exception should be thrown + @SuppressWarnings("unchecked") + public void testOnFailureNull() throws IOException { + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + AnomalyResultTransportAction.RCFActionListener listener = action.new RCFActionListener( + null, null, null, null, mock(ActionListener.class), null, null + ); + listener.onFailure(null); + } + + static class ColdStartConfig { + boolean coldStartRunning = false; + Exception getCheckpointException = null; + + ColdStartConfig(Builder builder) { + this.coldStartRunning = builder.coldStartRunning; + this.getCheckpointException = builder.getCheckpointException; + } + + static class Builder { + boolean coldStartRunning = false; + Exception getCheckpointException = null; + + Builder coldStartRunning(boolean coldStartRunning) { + this.coldStartRunning = coldStartRunning; + return this; + } + + Builder getCheckpointException(Exception exception) { + this.getCheckpointException = exception; + return this; + } + + public ColdStartConfig build() { + return new ColdStartConfig(this); + } + } + } + + @SuppressWarnings("unchecked") + private void setUpColdStart(ThreadPool mockThreadPool, ColdStartConfig config) { + SinglePointFeatures mockSinglePoint = mock(SinglePointFeatures.class); + + when(mockSinglePoint.getProcessedFeatures()).thenReturn(Optional.empty()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockSinglePoint); + return null; + }).when(featureQuery).getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + if (config.getCheckpointException == null) { + listener.onResponse(Boolean.FALSE); + } else { + listener.onFailure(config.getCheckpointException); + } + + return null; + }).when(stateManager).getDetectorCheckpoint(any(String.class), any(ActionListener.class)); + + when(stateManager.isColdStartRunning(any(String.class))).thenReturn(config.coldStartRunning); + + setUpADThreadPool(mockThreadPool); + } + + @SuppressWarnings("unchecked") + public void testColdStartNoTrainingData() throws Exception { + ThreadPool mockThreadPool = mock(ThreadPool.class); + setUpColdStart(mockThreadPool, new ColdStartConfig.Builder().coldStartRunning(false).build()); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(featureQuery).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + 
adCircuitBreakerService, + adStats, + mockThreadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + verify(stateManager, times(1)).setException(eq(adID), any(EndRunException.class)); + verify(stateManager, times(1)).markColdStartRunning(eq(adID)); + } + + @SuppressWarnings("unchecked") + public void testConcurrentColdStart() throws Exception { + ThreadPool mockThreadPool = mock(ThreadPool.class); + setUpColdStart(mockThreadPool, new ColdStartConfig.Builder().coldStartRunning(true).build()); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(featureQuery).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + mockThreadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + verify(stateManager, never()).setException(eq(adID), any(EndRunException.class)); + verify(stateManager, never()).markColdStartRunning(eq(adID)); + } + + @SuppressWarnings("unchecked") + public void testColdStartTimeoutPutCheckpoint() throws Exception { + ThreadPool mockThreadPool = mock(ThreadPool.class); + setUpColdStart(mockThreadPool, new ColdStartConfig.Builder().coldStartRunning(false).build()); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(new double[][] { { 1.0 } })); + return null; + }).when(featureQuery).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onFailure(new OpenSearchTimeoutException("")); + return null; + }).when(normalModelManager).trainModel(any(AnomalyDetector.class), any(double[][].class), any(ActionListener.class)); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + mockThreadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + verify(stateManager, times(1)).setException(eq(adID), any(InternalFailure.class)); + verify(stateManager, times(1)).markColdStartRunning(eq(adID)); + } + + @SuppressWarnings("unchecked") + public void testColdStartIllegalArgumentException() throws Exception { + ThreadPool mockThreadPool = mock(ThreadPool.class); + setUpColdStart(mockThreadPool, new ColdStartConfig.Builder().coldStartRunning(false).build()); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(new 
double[][] { { 1.0 } })); + return null; + }).when(featureQuery).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onFailure(new IllegalArgumentException("")); + return null; + }).when(normalModelManager).trainModel(any(AnomalyDetector.class), any(double[][].class), any(ActionListener.class)); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + mockThreadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + verify(stateManager, times(1)).setException(eq(adID), any(EndRunException.class)); + verify(stateManager, times(1)).markColdStartRunning(eq(adID)); + } + + enum FeatureTestMode { + FEATURE_NOT_AVAILABLE, + ILLEGAL_STATE, + AD_EXCEPTION + } + + @SuppressWarnings("unchecked") + public void featureTestTemplate(FeatureTestMode mode) throws IOException { + if (mode == FeatureTestMode.FEATURE_NOT_AVAILABLE) { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new SinglePointFeatures(Optional.empty(), Optional.empty())); + return null; + }).when(featureQuery).getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); + } else if (mode == FeatureTestMode.ILLEGAL_STATE) { + doThrow(IllegalArgumentException.class) + .when(featureQuery) + .getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); + } else if (mode == FeatureTestMode.AD_EXCEPTION) { + doThrow(TimeSeriesException.class) + .when(featureQuery) + .getCurrentFeatures(any(AnomalyDetector.class), anyLong(), anyLong(), any(ActionListener.class)); + } + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + if (mode == FeatureTestMode.FEATURE_NOT_AVAILABLE) { + AnomalyResultResponse response = listener.actionGet(); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.001); + assertEquals(Double.NaN, response.getConfidence(), 0.001); + assertEquals(Double.NaN, response.getAnomalyScore(), 0.001); + assertThat(response.getFeatures(), is(empty())); + } else if (mode == FeatureTestMode.ILLEGAL_STATE || mode == FeatureTestMode.AD_EXCEPTION) { + assertException(listener, InternalFailure.class); + } + } + + public void testFeatureNotAvailable() throws IOException { + featureTestTemplate(FeatureTestMode.FEATURE_NOT_AVAILABLE); + } + + public void testFeatureIllegalState() throws IOException { + featureTestTemplate(FeatureTestMode.ILLEGAL_STATE); + } + + public void testFeatureAnomalyException() throws IOException { + 
featureTestTemplate(FeatureTestMode.AD_EXCEPTION); + } + + enum BlockType { + INDEX_BLOCK, + GLOBAL_BLOCK_WRITE, + GLOBAL_BLOCK_READ + } + + private void globalBlockTemplate(BlockType type, String errLogMsg, Settings indexSettings, String indexName) { + ClusterState blockedClusterState = null; + + switch (type) { + case GLOBAL_BLOCK_WRITE: + blockedClusterState = ClusterState + .builder(new ClusterName("test cluster")) + .blocks(ClusterBlocks.builder().addGlobalBlock(IndexMetadata.INDEX_WRITE_BLOCK)) + .build(); + break; + case GLOBAL_BLOCK_READ: + blockedClusterState = ClusterState + .builder(new ClusterName("test cluster")) + .blocks(ClusterBlocks.builder().addGlobalBlock(IndexMetadata.INDEX_READ_BLOCK)) + .build(); + break; + case INDEX_BLOCK: + blockedClusterState = createIndexBlockedState(indexName, indexSettings, null); + break; + default: + break; + } + + ClusterService hackedClusterService = spy(clusterService); + when(hackedClusterService.state()).thenReturn(blockedClusterState); + + // These constructors register handler in transport service + new RCFResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + normalModelManager, + adCircuitBreakerService, + hashRing, + adStats + ); + new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + hackedClusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + assertException(listener, TimeSeriesException.class, errLogMsg); + } + + private void globalBlockTemplate(BlockType type, String errLogMsg) { + globalBlockTemplate(type, errLogMsg, null, null); + } + + public void testReadBlock() { + globalBlockTemplate(BlockType.GLOBAL_BLOCK_READ, AnomalyResultTransportAction.READ_WRITE_BLOCKED); + } + + public void testWriteBlock() { + globalBlockTemplate(BlockType.GLOBAL_BLOCK_WRITE, AnomalyResultTransportAction.READ_WRITE_BLOCKED); + } + + public void testIndexReadBlock() { + globalBlockTemplate( + BlockType.INDEX_BLOCK, + AnomalyResultTransportAction.INDEX_READ_BLOCKED, + Settings.builder().put(IndexMetadata.INDEX_BLOCKS_READ_SETTING.getKey(), true).build(), + "test1" + ); + } + + @SuppressWarnings("unchecked") + public void testNullRCFResult() { + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + AnomalyResultTransportAction.RCFActionListener listener = action.new RCFActionListener( + "123-rcf-0", null, "123", null, mock(ActionListener.class), null, null + ); + listener.onResponse(null); + assertTrue(testAppender.containsMessage(AnomalyResultTransportAction.NULL_RESPONSE)); + } + + @SuppressWarnings("unchecked") + public void testNormalRCFResult() { + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + 
new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + ActionListener listener = mock(ActionListener.class); + AnomalyResultTransportAction.RCFActionListener rcfListener = action.new RCFActionListener( + "123-rcf-0", null, "nodeID", detector, listener, null, adID + ); + double[] attribution = new double[] { 1. }; + long totalUpdates = 32; + double grade = 0.5; + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(AnomalyResultResponse.class); + rcfListener + .onResponse(new RCFResultResponse(0.3, 0, 26, attribution, totalUpdates, grade, Version.CURRENT, 0, null, null, null, 1.1)); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(grade, responseCaptor.getValue().getAnomalyGrade(), 1e-10); + } + + @SuppressWarnings("unchecked") + public void testNullPointerRCFResult() { + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + ActionListener listener = mock(ActionListener.class); + // detector being null causes NullPointerException + AnomalyResultTransportAction.RCFActionListener rcfListener = action.new RCFActionListener( + "123-rcf-0", null, "nodeID", null, listener, null, adID + ); + double[] attribution = new double[] { 1. }; + long totalUpdates = 32; + double grade = 0.5; + ArgumentCaptor failureCaptor = ArgumentCaptor.forClass(Exception.class); + rcfListener + .onResponse(new RCFResultResponse(0.3, 0, 26, attribution, totalUpdates, grade, Version.CURRENT, 0, null, null, null, 1.1)); + verify(listener, times(1)).onFailure(failureCaptor.capture()); + Exception failure = failureCaptor.getValue(); + assertTrue(failure instanceof InternalFailure); + } + + @SuppressWarnings("unchecked") + public void testAllFeaturesDisabled() throws IOException { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onFailure(new EndRunException(adID, CommonMessages.ALL_FEATURES_DISABLED_ERR_MSG, true)); + return null; + }).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + assertException(listener, EndRunException.class, CommonMessages.ALL_FEATURES_DISABLED_ERR_MSG); + } + + @SuppressWarnings("unchecked") + public void testEndRunDueToNoTrainingData() { + ThreadPool mockThreadPool = mock(ThreadPool.class); + setUpColdStart(mockThreadPool, new ColdStartConfig.Builder().coldStartRunning(false).build()); + + ModelManager rcfManager = mock(ModelManager.class); + 
doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[3]; + listener.onFailure(new IndexNotFoundException(ADCommonName.CHECKPOINT_INDEX_NAME)); + return null; + }).when(rcfManager).getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); + + when(stateManager.fetchExceptionAndClear(any(String.class))) + .thenReturn(Optional.of(new EndRunException(adID, "Cannot get training data", false))); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(new double[][] { { 1.0 } })); + return null; + }).when(featureQuery).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(normalModelManager).trainModel(any(AnomalyDetector.class), any(double[][].class), any(ActionListener.class)); + + // These constructors register handler in transport service + new RCFResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + rcfManager, + adCircuitBreakerService, + hashRing, + adStats + ); + new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + mockThreadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + assertException(listener, EndRunException.class); + verify(stateManager, times(1)).markColdStartRunning(eq(adID)); + } + + @SuppressWarnings({ "unchecked" }) + public void testColdStartEndRunException() { + ThreadPool mockThreadPool = mock(ThreadPool.class); + setUpColdStart(mockThreadPool, new ColdStartConfig.Builder().coldStartRunning(false).build()); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(featureQuery).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); + + when(stateManager.fetchExceptionAndClear(anyString())) + .thenReturn( + Optional + .of( + new EndRunException( + adID, + CommonMessages.INVALID_SEARCH_QUERY_MSG, + new NoSuchElementException("No value present"), + false + ) + ) + ); + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + mockThreadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + action.doExecute(null, request, listener); + assertException(listener, EndRunException.class, CommonMessages.INVALID_SEARCH_QUERY_MSG); + verify(featureQuery, times(1)).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); + } + + 
@SuppressWarnings({ "unchecked" }) + public void testColdStartEndRunExceptionNow() { + ThreadPool mockThreadPool = mock(ThreadPool.class); + setUpColdStart(mockThreadPool, new ColdStartConfig.Builder().coldStartRunning(false).build()); + + when(stateManager.fetchExceptionAndClear(anyString())) + .thenReturn( + Optional + .of( + new EndRunException( + adID, + CommonMessages.INVALID_SEARCH_QUERY_MSG, + new NoSuchElementException("No value present"), + true + ) + ) + ); + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + mockThreadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + action.doExecute(null, request, listener); + assertException(listener, EndRunException.class, CommonMessages.INVALID_SEARCH_QUERY_MSG); + verify(featureQuery, never()).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); + } + + @SuppressWarnings({ "unchecked" }) + public void testColdStartBecauseFailtoGetCheckpoint() { + ThreadPool mockThreadPool = mock(ThreadPool.class); + setUpColdStart( + mockThreadPool, + new ColdStartConfig.Builder().getCheckpointException(new IndexNotFoundException(ADCommonName.CHECKPOINT_INDEX_NAME)).build() + ); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(featureQuery).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + mockThreadPool, + NamedXContentRegistry.EMPTY, + adTaskManager + ); + + action.doExecute(null, request, listener); + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.001); + verify(featureQuery, times(1)).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); + } + + @SuppressWarnings({ "unchecked" }) + public void testNoColdStartDueToUnknownException() { + ThreadPool mockThreadPool = mock(ThreadPool.class); + setUpColdStart(mockThreadPool, new ColdStartConfig.Builder().getCheckpointException(new RuntimeException()).build()); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(featureQuery).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class)); + + AnomalyResultRequest request = new AnomalyResultRequest(adID, 100, 200); + PlainActionFuture listener = new PlainActionFuture<>(); + + AnomalyResultTransportAction action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + 
            adCircuitBreakerService,
+            adStats,
+            mockThreadPool,
+            NamedXContentRegistry.EMPTY,
+            adTaskManager
+        );
+
+        action.doExecute(null, request, listener);
+        AnomalyResultResponse response = listener.actionGet(10000L);
+        assertEquals(Double.NaN, response.getAnomalyGrade(), 0.001);
+        verify(featureQuery, never()).getColdStartData(any(AnomalyDetector.class), any(ActionListener.class));
+    }
+}
diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java
index 562b7de69..0cd9218f0 100644
--- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java
@@ -26,7 +26,7 @@
 import org.opensearch.ad.ADIntegTestCase;
 import org.opensearch.ad.model.AnomalyDetector;
 import org.opensearch.ad.util.ExceptionUtil;
-import org.opensearch.common.io.stream.NotSerializableExceptionWrapper;
+import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper;
 import org.opensearch.search.aggregations.AggregationBuilder;
 import org.opensearch.test.rest.OpenSearchRestTestCase;
 import org.opensearch.timeseries.TestHelpers;
diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java-e
new file mode 100644
index 000000000..0cd9218f0
--- /dev/null
+++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java-e
@@ -0,0 +1,289 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.timeseries.TestHelpers.randomQuery; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.List; +import java.util.Map; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.junit.Before; +import org.opensearch.action.get.GetResponse; +import org.opensearch.ad.ADIntegTestCase; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.util.ExceptionUtil; +import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.test.rest.OpenSearchRestTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class AnomalyResultTransportActionTests extends ADIntegTestCase { + private static final Logger LOG = LogManager.getLogger(AnomalyResultTransportActionTests.class); + + private String testIndex; + private Instant testDataTimeStamp; + private long start; + private long end; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + testIndex = "test_data"; + testDataTimeStamp = Instant.now(); + start = testDataTimeStamp.minus(10, ChronoUnit.MINUTES).toEpochMilli(); + end = testDataTimeStamp.plus(10, ChronoUnit.MINUTES).toEpochMilli(); + ingestTestData(); + } + + private void ingestTestData() throws IOException, InterruptedException { + createTestDataIndex(testIndex); + double value = randomDouble(); + String type = randomAlphaOfLength(5); + boolean isError = randomBoolean(); + String message = randomAlphaOfLength(10); + String id = indexDoc( + testIndex, + ImmutableMap + .of(timeField, testDataTimeStamp.toEpochMilli(), "value", value, "type", type, "is_error", isError, "message", message) + ); + GetResponse doc = getDoc(testIndex, id); + Map sourceAsMap = doc.getSourceAsMap(); + assertEquals(testDataTimeStamp.toEpochMilli(), sourceAsMap.get(timeField)); + assertEquals(value, sourceAsMap.get("value")); + assertEquals(type, sourceAsMap.get("type")); + assertEquals(isError, sourceAsMap.get("is_error")); + assertEquals(message, sourceAsMap.get("message")); + createDetectorIndex(); + } + + public void testFeatureQueryWithTermsAggregation() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"terms\":{\"field\":\"type\"}}}"); + assertErrorMessage(adId, "Failed to parse aggregation"); + } + + public void testFeatureWithSumOfTextField() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"sum\":{\"field\":\"message\"}}}"); + assertErrorMessage(adId, "Text fields are not optimised for operations"); + } + + public void testFeatureWithSumOfTypeField() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"sum\":{\"field\":\"type\"}}}"); + assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [sum]"); + } + + public void testFeatureWithMaxOfTextField() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"max\":{\"field\":\"message\"}}}"); + assertErrorMessage(adId, "Text fields are not optimised for operations"); + } + + public void 
testFeatureWithMaxOfTypeField() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"max\":{\"field\":\"type\"}}}"); + assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [max]"); + } + + public void testFeatureWithMinOfTextField() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"min\":{\"field\":\"message\"}}}"); + assertErrorMessage(adId, "Text fields are not optimised for operations"); + } + + public void testFeatureWithMinOfTypeField() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"min\":{\"field\":\"type\"}}}"); + assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [min]"); + } + + public void testFeatureWithAvgOfTextField() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"avg\":{\"field\":\"message\"}}}"); + assertErrorMessage(adId, "Text fields are not optimised for operations"); + } + + public void testFeatureWithAvgOfTypeField() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"avg\":{\"field\":\"type\"}}}"); + assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [avg]"); + } + + public void testFeatureWithCountOfTextField() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"value_count\":{\"field\":\"message\"}}}"); + assertErrorMessage(adId, "Text fields are not optimised for operations"); + } + + public void testFeatureWithCardinalityOfTextField() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"cardinality\":{\"field\":\"message\"}}}"); + assertErrorMessage(adId, "Text fields are not optimised for operations"); + } + + public void testFeatureQueryWithTermsAggregationForHCDetector() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"terms\":{\"field\":\"type\"}}}", true); + assertErrorMessage(adId, "Failed to parse aggregation", true); + } + + public void testFeatureWithSumOfTextFieldForHCDetector() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"sum\":{\"field\":\"message\"}}}", true); + assertErrorMessage(adId, "Text fields are not optimised for operations", true); + } + + public void testFeatureWithSumOfTypeFieldForHCDetector() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"sum\":{\"field\":\"type\"}}}", true); + assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [sum]", true); + } + + public void testFeatureWithMaxOfTextFieldForHCDetector() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"max\":{\"field\":\"message\"}}}", true); + assertErrorMessage(adId, "Text fields are not optimised for operations", true); + } + + public void testFeatureWithMaxOfTypeFieldForHCDetector() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"max\":{\"field\":\"type\"}}}", true); + assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [max]", true); + } + + public void testFeatureWithMinOfTextFieldForHCDetector() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"min\":{\"field\":\"message\"}}}", true); + assertErrorMessage(adId, "Text fields are not optimised for operations", true); + } + + public void testFeatureWithMinOfTypeFieldForHCDetector() throws IOException { + String adId = 
createDetectorWithFeatureAgg("{\"test\":{\"min\":{\"field\":\"type\"}}}", true); + assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [min]", true); + } + + public void testFeatureWithAvgOfTextFieldForHCDetector() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"avg\":{\"field\":\"message\"}}}", true); + assertErrorMessage(adId, "Text fields are not optimised for operations", true); + } + + public void testFeatureWithAvgOfTypeFieldForHCDetector() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"avg\":{\"field\":\"type\"}}}", true); + assertErrorMessage(adId, "Field [type] of type [keyword] is not supported for aggregation [avg]", true); + } + + public void testFeatureWithCountOfTextFieldForHCDetector() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"value_count\":{\"field\":\"message\"}}}", true); + assertErrorMessage(adId, "Text fields are not optimised for operations", true); + } + + public void testFeatureWithCardinalityOfTextFieldForHCDetector() throws IOException { + String adId = createDetectorWithFeatureAgg("{\"test\":{\"cardinality\":{\"field\":\"message\"}}}", true); + assertErrorMessage(adId, "Text fields are not optimised for operations", true); + } + + private String createDetectorWithFeatureAgg(String aggQuery) throws IOException { + return createDetectorWithFeatureAgg(aggQuery, false); + } + + private String createDetectorWithFeatureAgg(String aggQuery, boolean hcDetector) throws IOException { + AggregationBuilder aggregationBuilder = TestHelpers.parseAggregation(aggQuery); + Feature feature = new Feature(randomAlphaOfLength(5), randomAlphaOfLength(10), true, aggregationBuilder); + AnomalyDetector detector = hcDetector + ? 
randomHCDetector(ImmutableList.of(testIndex), ImmutableList.of(feature))
+            : randomDetector(ImmutableList.of(testIndex), ImmutableList.of(feature));
+        String adId = createDetector(detector);
+        return adId;
+    }
+
+    private AnomalyDetector randomDetector(List<String> indices, List<Feature> features) throws IOException {
+        return new AnomalyDetector(
+            randomAlphaOfLength(10),
+            randomLong(),
+            randomAlphaOfLength(20),
+            randomAlphaOfLength(30),
+            timeField,
+            indices,
+            features,
+            randomQuery("{\"bool\":{\"filter\":[{\"exists\":{\"field\":\"value\"}}]}}"),
+            new IntervalTimeConfiguration(OpenSearchRestTestCase.randomLongBetween(1, 5), ChronoUnit.MINUTES),
+            new IntervalTimeConfiguration(OpenSearchRestTestCase.randomLongBetween(1, 5), ChronoUnit.MINUTES),
+            8,
+            null,
+            randomInt(),
+            Instant.now(),
+            null,
+            null,
+            null,
+            TestHelpers.randomImputationOption()
+        );
+    }
+
+    private AnomalyDetector randomHCDetector(List<String> indices, List<Feature> features) throws IOException {
+        return new AnomalyDetector(
+            randomAlphaOfLength(10),
+            randomLong(),
+            randomAlphaOfLength(20),
+            randomAlphaOfLength(30),
+            timeField,
+            indices,
+            features,
+            randomQuery("{\"bool\":{\"filter\":[{\"exists\":{\"field\":\"value\"}}]}}"),
+            new IntervalTimeConfiguration(OpenSearchRestTestCase.randomLongBetween(1, 5), ChronoUnit.MINUTES),
+            new IntervalTimeConfiguration(OpenSearchRestTestCase.randomLongBetween(1, 5), ChronoUnit.MINUTES),
+            8,
+            null,
+            randomInt(),
+            Instant.now(),
+            ImmutableList.of(categoryField),
+            null,
+            null,
+            TestHelpers.randomImputationOption()
+        );
+    }
+
+    private void assertErrorMessage(String adId, String errorMessage, boolean hcDetector) {
+        AnomalyResultRequest resultRequest = new AnomalyResultRequest(adId, start, end);
+        try {
+            Thread.sleep(1000); // sleep some time to build AD version hash ring
+        } catch (InterruptedException e) {
+            throw new RuntimeException("Failed to sleep before calling AD result action");
+        }
+        // wait at most 20 seconds
+        int numberOfTries = 40;
+        Exception e = null;
+        if (hcDetector) {
+            while (numberOfTries-- > 0) {
+                try {
+                    // HCAD records failures asynchronously. Before a failure is recorded, HCAD returns immediately without failure.
+ client().execute(AnomalyResultAction.INSTANCE, resultRequest).actionGet(30_000); + Thread.sleep(500); + } catch (Exception exp) { + e = exp; + break; + } + } + } else { + e = expectThrowsAnyOf( + ImmutableList.of(NotSerializableExceptionWrapper.class, TimeSeriesException.class), + () -> client().execute(AnomalyResultAction.INSTANCE, resultRequest).actionGet(30_000) + ); + } + String stackErrorMessage = ExceptionUtil.getErrorMessage(e); + assertTrue( + "Unexpected error: " + e.getMessage(), + stackErrorMessage.contains(errorMessage) + || stackErrorMessage.contains("node is not available") + || stackErrorMessage.contains("AD memory circuit is broken") + ); + } + + private void assertErrorMessage(String adId, String errorMessage) { + assertErrorMessage(adId, errorMessage, false); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java index bcb49d81b..a65c35839 100644 --- a/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java @@ -36,8 +36,8 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.threadpool.ThreadPool; diff --git a/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java-e new file mode 100644 index 000000000..15085f9de --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java-e @@ -0,0 +1,130 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */
+
+package org.opensearch.ad.transport;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.function.Function;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.opensearch.Version;
+import org.opensearch.action.support.ActionFilters;
+import org.opensearch.ad.NodeStateManager;
+import org.opensearch.ad.caching.CacheProvider;
+import org.opensearch.ad.caching.EntityCache;
+import org.opensearch.ad.common.exception.JsonPathNotFoundException;
+import org.opensearch.ad.feature.FeatureManager;
+import org.opensearch.ad.ml.EntityColdStarter;
+import org.opensearch.ad.ml.ModelManager;
+import org.opensearch.ad.task.ADTaskManager;
+import org.opensearch.cluster.ClusterName;
+import org.opensearch.cluster.node.DiscoveryNode;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.Strings;
+import org.opensearch.common.io.stream.BytesStreamOutput;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.common.xcontent.XContentFactory;
+import org.opensearch.core.xcontent.ToXContent;
+import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.threadpool.ThreadPool;
+import org.opensearch.timeseries.AbstractTimeSeriesTest;
+import org.opensearch.transport.TransportService;
+
+import test.org.opensearch.ad.util.JsonDeserializer;
+
+import com.google.gson.JsonElement;
+
+public class CronTransportActionTests extends AbstractTimeSeriesTest {
+    private CronTransportAction action;
+    private String localNodeID;
+
+    @Override
+    @Before
+    public void setUp() throws Exception {
+        super.setUp();
+        ThreadPool threadPool = mock(ThreadPool.class);
+
+        ClusterService clusterService = mock(ClusterService.class);
+        localNodeID = "foo";
+        when(clusterService.localNode()).thenReturn(new DiscoveryNode(localNodeID, buildNewFakeTransportAddress(), Version.CURRENT));
+        when(clusterService.getClusterName()).thenReturn(new ClusterName("test"));
+
+        TransportService transportService = mock(TransportService.class);
+        ActionFilters actionFilters = mock(ActionFilters.class);
+        NodeStateManager transportStateManager = mock(NodeStateManager.class);
+        ModelManager modelManager = mock(ModelManager.class);
+        FeatureManager featureManager = mock(FeatureManager.class);
+        CacheProvider cacheProvider = mock(CacheProvider.class);
+        EntityCache entityCache = mock(EntityCache.class);
+        EntityColdStarter entityColdStarter = mock(EntityColdStarter.class);
+        when(cacheProvider.get()).thenReturn(entityCache);
+        ADTaskManager adTaskManager = mock(ADTaskManager.class);
+
+        action = new CronTransportAction(
+            threadPool,
+            clusterService,
+            transportService,
+            actionFilters,
+            transportStateManager,
+            modelManager,
+            featureManager,
+            cacheProvider,
+            entityColdStarter,
+            adTaskManager
+        );
+    }
+
+    public void testNormal() throws IOException, JsonPathNotFoundException {
+        CronRequest request = new CronRequest();
+
+        CronNodeRequest nodeRequest = new CronNodeRequest();
+        BytesStreamOutput nodeRequestOut = new BytesStreamOutput();
+        nodeRequestOut.setVersion(Version.V_2_0_0);
+        nodeRequest.writeTo(nodeRequestOut);
+        StreamInput siNode = nodeRequestOut.bytes().streamInput();
+        siNode.setVersion(Version.V_2_0_0);
+
+        CronNodeRequest nodeResponseRead = new CronNodeRequest(siNode);
+
+        CronNodeResponse nodeResponse1 = action.nodeOperation(nodeResponseRead);
+        CronNodeResponse nodeResponse2 = action.nodeOperation(new
CronNodeRequest()); + + CronResponse response = action.newResponse(request, Arrays.asList(nodeResponse1, nodeResponse2), Collections.emptyList()); + + assertEquals(2, response.getNodes().size()); + assertTrue(!response.hasFailures()); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + + String json = Strings.toString(builder); + Function function = (s) -> { + try { + return JsonDeserializer.getTextValue(s, CronNodeResponse.NODE_ID); + } catch (Exception e) { + Assert.fail(e.getMessage()); + } + return null; + }; + assertArrayEquals( + JsonDeserializer.getArrayValue(json, function, CronResponse.NODES_JSON_KEY), + new String[] { localNodeID, localNodeID } + ); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java index 809033ebc..ca7fae8ba 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java @@ -30,9 +30,9 @@ import org.opensearch.ad.task.ADTaskManager; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.transport.TransportService; diff --git a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java-e b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java-e new file mode 100644 index 000000000..08ebf81a8 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java-e @@ -0,0 +1,116 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.transport.TransportService; + +public class DeleteAnomalyDetectorActionTests extends OpenSearchIntegTestCase { + private DeleteAnomalyDetectorTransportAction action; + private ActionListener response; + private ADTaskManager adTaskManager; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + ClusterService clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + adTaskManager = mock(ADTaskManager.class); + action = new DeleteAnomalyDetectorTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + client(), + clusterService, + Settings.EMPTY, + xContentRegistry(), + adTaskManager + ); + response = new ActionListener() { + @Override + public void onResponse(DeleteResponse deleteResponse) { + Assert.assertTrue(true); + } + + @Override + public void onFailure(Exception e) { + Assert.assertTrue(true); + } + }; + } + + @Test + public void testStatsAction() { + Assert.assertNotNull(DeleteAnomalyDetectorAction.INSTANCE.name()); + Assert.assertEquals(DeleteAnomalyDetectorAction.INSTANCE.name(), DeleteAnomalyDetectorAction.NAME); + } + + @Test + public void testDeleteRequest() throws IOException { + DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + StreamInput input = out.bytes().streamInput(); + DeleteAnomalyDetectorRequest newRequest = new DeleteAnomalyDetectorRequest(input); + Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID()); + Assert.assertNull(newRequest.validate()); + } + + @Test + public void testEmptyDeleteRequest() { + DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest(""); + ActionRequestValidationException exception = request.validate(); + Assert.assertNotNull(exception); + } + + @Test + public void testTransportActionWithAdIndex() { + // DeleteResponse is not called because detector ID will not exist + createIndex(".opendistro-anomaly-detector-jobs"); + DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + action.doExecute(mock(Task.class), request, response); + } + + @Test + public void testTransportActionWithoutAdIndex() throws IOException { + // DeleteResponse is not called because detector ID 
will not exist + DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + action.doExecute(mock(Task.class), request, response); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java index d2c3e634c..7b67843f1 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java @@ -46,9 +46,9 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.index.get.GetResult; import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; diff --git a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java-e b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java-e new file mode 100644 index 000000000..def276b06 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java-e @@ -0,0 +1,312 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.index.get.GetResult; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import 
org.opensearch.tasks.Task; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportService; + +public class DeleteAnomalyDetectorTests extends AbstractTimeSeriesTest { + private DeleteAnomalyDetectorTransportAction action; + private TransportService transportService; + private ActionFilters actionFilters; + private Client client; + private ADTaskManager adTaskManager; + private PlainActionFuture future; + private DeleteResponse deleteResponse; + private GetResponse getResponse; + ClusterService clusterService; + private AnomalyDetectorJob jobParameter; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(EntityProfileTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @Override + public void setUp() throws Exception { + super.setUp(); + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + transportService = new TransportService( + Settings.EMPTY, + mock(Transport.class), + null, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + x -> null, + null, + Collections.emptySet() + ); + + client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + + actionFilters = mock(ActionFilters.class); + adTaskManager = mock(ADTaskManager.class); + action = new DeleteAnomalyDetectorTransportAction( + transportService, + actionFilters, + client, + clusterService, + Settings.EMPTY, + xContentRegistry(), + adTaskManager + ); + + jobParameter = mock(AnomalyDetectorJob.class); + when(jobParameter.getName()).thenReturn(randomAlphaOfLength(10)); + IntervalSchedule schedule = new IntervalSchedule(Instant.now(), 1, ChronoUnit.MINUTES); + when(jobParameter.getSchedule()).thenReturn(schedule); + when(jobParameter.getWindowDelay()).thenReturn(new IntervalTimeConfiguration(10, ChronoUnit.SECONDS)); + } + + public void testDeleteADTransportAction_FailDeleteResponse() { + future = mock(PlainActionFuture.class); + DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + setupMocks(true, true, false, false); + + action.doExecute(mock(Task.class), request, future); + verify(adTaskManager).deleteADTasks(eq("1234"), any(), any()); + verify(client, times(1)).delete(any(), any()); + verify(future).onFailure(any(OpenSearchStatusException.class)); + } + + public void testDeleteADTransportAction_NullAnomalyDetector() { + future = mock(PlainActionFuture.class); + DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + setupMocks(true, false, false, false); + + action.doExecute(mock(Task.class), request, future); + verify(adTaskManager).deleteADTasks(eq("1234"), any(), any()); + verify(client, times(3)).delete(any(), any()); + } + + public void testDeleteADTransportAction_DeleteResponseException() { + future = mock(PlainActionFuture.class); + DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + setupMocks(true, false, true, false); + + action.doExecute(mock(Task.class), 
request, future); + verify(adTaskManager).deleteADTasks(eq("1234"), any(), any()); + verify(client, times(1)).delete(any(), any()); + verify(future).onFailure(any(RuntimeException.class)); + } + + public void testDeleteADTransportAction_LatestDetectorLevelTask() { + when(clusterService.state()).thenReturn(createClusterState()); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + Consumer> consumer = (Consumer>) args[2]; + ADTask adTask = ADTask.builder().state("RUNNING").build(); + consumer.accept(Optional.of(adTask)); + return null; + }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(eq("1234"), any(), any(), eq(transportService), eq(true), any()); + + future = mock(PlainActionFuture.class); + DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + setupMocks(false, false, false, false); + + action.doExecute(mock(Task.class), request, future); + verify(future).onFailure(any(OpenSearchStatusException.class)); + } + + public void testDeleteADTransportAction_JobRunning() { + when(clusterService.state()).thenReturn(createClusterState()); + future = mock(PlainActionFuture.class); + DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + setupMocks(false, false, false, false); + + action.doExecute(mock(Task.class), request, future); + verify(future).onFailure(any(RuntimeException.class)); + } + + public void testDeleteADTransportAction_GetResponseException() { + when(clusterService.state()).thenReturn(createClusterState()); + future = mock(PlainActionFuture.class); + DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234"); + setupMocks(false, false, false, true); + + action.doExecute(mock(Task.class), request, future); + verify(client).get(any(), any()); + verify(client).get(any(), any()); + } + + private ClusterState createClusterState() { + Map immutableOpenMap = new HashMap<>(); + immutableOpenMap + .put( + CommonName.JOB_INDEX, + IndexMetadata + .builder("test") + .settings( + Settings + .builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id) + ) + .build() + ); + Metadata metaData = Metadata.builder().indices(immutableOpenMap).build(); + ClusterState clusterState = new ClusterState( + new ClusterName("test_name"), + 1l, + "uuid", + metaData, + null, + null, + null, + new HashMap<>(), + 1, + true + ); + return clusterState; + } + + private void setupMocks( + boolean nullAnomalyDetectorResponse, + boolean failDeleteDeleteResponse, + boolean deleteResponseException, + boolean getResponseFailure + ) { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + Consumer> consumer = (Consumer>) args[1]; + if (nullAnomalyDetectorResponse) { + consumer.accept(Optional.empty()); + } else { + AnomalyDetector ad = mock(AnomalyDetector.class); + consumer.accept(Optional.of(ad)); + } + return null; + }).when(adTaskManager).getDetector(any(), any(), any()); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ExecutorFunction function = (ExecutorFunction) args[1]; + + function.execute(); + return null; + }).when(adTaskManager).deleteADTasks(eq("1234"), any(), any()); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + deleteResponse = mock(DeleteResponse.class); + if (deleteResponseException) { + listener.onFailure(new RuntimeException("Failed to delete anomaly detector job")); + return null; + } + 
if (failDeleteDeleteResponse) { + doReturn(DocWriteResponse.Result.CREATED).when(deleteResponse).getResult(); + } else { + doReturn(DocWriteResponse.Result.DELETED).when(deleteResponse).getResult(); + } + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + if (getResponseFailure) { + listener.onFailure(new RuntimeException("Fail to get anomaly detector job")); + return null; + } + getResponse = new GetResponse( + new GetResult( + CommonName.JOB_INDEX, + "id", + UNASSIGNED_SEQ_NO, + 0, + -1, + true, + BytesReference + .bytes( + new AnomalyDetectorJob( + "1234", + jobParameter.getSchedule(), + jobParameter.getWindowDelay(), + true, + Instant.now().minusSeconds(60), + Instant.now(), + Instant.now(), + 60L, + TestHelpers.randomUser(), + jobParameter.getCustomResultIndex() + ).toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS) + ), + Collections.emptyMap(), + Collections.emptyMap() + ) + ); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportActionTests.java-e new file mode 100644 index 000000000..ac81ecf25 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportActionTests.java-e @@ -0,0 +1,67 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; + +import org.junit.Before; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.ad.HistoricalAnalysisIntegTestCase; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.Feature; + +import com.google.common.collect.ImmutableList; + +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2) +public class DeleteAnomalyDetectorTransportActionTests extends HistoricalAnalysisIntegTestCase { + private Instant startTime; + private Instant endTime; + private String type = "error"; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + startTime = Instant.now().minus(10, ChronoUnit.DAYS); + endTime = Instant.now(); + ingestTestData(testIndex, startTime, detectionIntervalInMinutes, type, 2000); + createDetectorIndex(); + } + + public void testDeleteAnomalyDetectorWithoutFeature() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(null); + testDeleteDetector(detector); + } + + public void testDeleteAnomalyDetectorWithoutEnabledFeature() throws IOException { + Feature feature = TestHelpers.randomFeature(false); + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableList.of(feature)); + testDeleteDetector(detector); + } + + public void testDeleteAnomalyDetectorWithEnabledFeature() throws IOException { + Feature feature = TestHelpers.randomFeature(true); + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableList.of(feature)); + testDeleteDetector(detector); + } + + private void testDeleteDetector(AnomalyDetector detector) throws IOException { + String detectorId = createDetector(detector); + DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest(detectorId); + DeleteResponse deleteResponse = client().execute(DeleteAnomalyDetectorAction.INSTANCE, request).actionGet(10000); + assertEquals("deleted", deleteResponse.getResult().getLowercase()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportActionTests.java-e new file mode 100644 index 000000000..5653a577c --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportActionTests.java-e @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.constant.ADCommonName.ANOMALY_RESULT_INDEX_ALIAS; +import static org.opensearch.timeseries.TestHelpers.matchAllRequest; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +import org.junit.Ignore; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.HistoricalAnalysisIntegTestCase; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.timeseries.TestHelpers; + +public class DeleteAnomalyResultsTransportActionTests extends HistoricalAnalysisIntegTestCase { + + // TODO: fix flaky test + @Ignore + public void testDeleteADResultAction() throws IOException, InterruptedException { + createADResultIndex(); + String adResultId = createADResult(TestHelpers.randomAnomalyDetectResult()); + + SearchResponse searchResponse = client().execute(SearchAnomalyResultAction.INSTANCE, matchAllRequest()).actionGet(10000); + assertEquals(1, searchResponse.getInternalResponse().hits().getTotalHits().value); + + assertEquals(adResultId, searchResponse.getInternalResponse().hits().getAt(0).getId()); + DeleteByQueryRequest deleteByQueryRequest = new DeleteByQueryRequest(ANOMALY_RESULT_INDEX_ALIAS); + deleteByQueryRequest.setQuery(new BoolQueryBuilder().filter(new MatchAllQueryBuilder())); + BulkByScrollResponse deleteADResultResponse = client() + .execute(DeleteAnomalyResultsAction.INSTANCE, deleteByQueryRequest) + .actionGet(20000); + waitUntil(() -> { + SearchResponse response = client().execute(SearchAnomalyResultAction.INSTANCE, matchAllRequest()).actionGet(10000); + return response.getInternalResponse().hits().getTotalHits().value == 0; + }, 90, TimeUnit.SECONDS); + assertEquals(1, deleteADResultResponse.getDeleted()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/DeleteITTests.java b/src/test/java/org/opensearch/ad/transport/DeleteITTests.java index b9c130041..aeb0b7165 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteITTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteITTests.java @@ -18,18 +18,18 @@ import org.opensearch.action.ActionFuture; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ad.ADIntegTestCase; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.plugins.Plugin; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; public class DeleteITTests extends ADIntegTestCase { @Override protected Collection> nodePlugins() { - return Collections.singletonList(AnomalyDetectorPlugin.class); + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } protected Collection> transportClientPlugins() { - return Collections.singletonList(AnomalyDetectorPlugin.class); + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } public void testNormalStopDetector() throws ExecutionException, InterruptedException { diff --git a/src/test/java/org/opensearch/ad/transport/DeleteITTests.java-e b/src/test/java/org/opensearch/ad/transport/DeleteITTests.java-e new file mode 100644 index 000000000..aeb0b7165 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/DeleteITTests.java-e @@ -0,0 +1,68 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * 
compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.util.Collection; +import java.util.Collections; +import java.util.concurrent.ExecutionException; + +import org.opensearch.action.ActionFuture; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.ADIntegTestCase; +import org.opensearch.plugins.Plugin; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +public class DeleteITTests extends ADIntegTestCase { + + @Override + protected Collection> nodePlugins() { + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); + } + + protected Collection> transportClientPlugins() { + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); + } + + public void testNormalStopDetector() throws ExecutionException, InterruptedException { + StopDetectorRequest request = new StopDetectorRequest().adID("123"); + + ActionFuture future = client().execute(StopDetectorAction.INSTANCE, request); + + StopDetectorResponse response = future.get(); + assertTrue(response.success()); + } + + public void testNormalDeleteModel() throws ExecutionException, InterruptedException { + DeleteModelRequest request = new DeleteModelRequest("123"); + + ActionFuture future = client().execute(DeleteModelAction.INSTANCE, request); + + DeleteModelResponse response = future.get(); + assertTrue(!response.hasFailures()); + } + + public void testEmptyIDDeleteModel() throws ExecutionException, InterruptedException { + DeleteModelRequest request = new DeleteModelRequest(""); + + ActionFuture future = client().execute(DeleteModelAction.INSTANCE, request); + + expectThrows(ActionRequestValidationException.class, () -> future.actionGet()); + } + + public void testEmptyIDStopDetector() throws ExecutionException, InterruptedException { + StopDetectorRequest request = new StopDetectorRequest(); + + ActionFuture future = client().execute(StopDetectorAction.INSTANCE, request); + + expectThrows(ActionRequestValidationException.class, () -> future.actionGet()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java index 9246a63ce..fd74a2802 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java @@ -41,8 +41,8 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.threadpool.ThreadPool; diff --git a/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java-e new file mode 100644 index 000000000..a05859b1d --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java-e @@ -0,0 +1,140 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.function.Function; + +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Before; +import org.opensearch.Version; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.common.exception.JsonPathNotFoundException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.ml.EntityColdStarter; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.transport.TransportService; + +import test.org.opensearch.ad.util.JsonDeserializer; + +import com.google.gson.JsonElement; + +public class DeleteModelTransportActionTests extends AbstractTimeSeriesTest { + private DeleteModelTransportAction action; + private String localNodeID; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + ThreadPool threadPool = mock(ThreadPool.class); + + ClusterService clusterService = mock(ClusterService.class); + localNodeID = "foo"; + when(clusterService.localNode()).thenReturn(new DiscoveryNode(localNodeID, buildNewFakeTransportAddress(), Version.CURRENT)); + when(clusterService.getClusterName()).thenReturn(new ClusterName("test")); + + TransportService transportService = mock(TransportService.class); + ActionFilters actionFilters = mock(ActionFilters.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + ModelManager modelManager = mock(ModelManager.class); + FeatureManager featureManager = mock(FeatureManager.class); + CacheProvider cacheProvider = mock(CacheProvider.class); + EntityCache entityCache = mock(EntityCache.class); + when(cacheProvider.get()).thenReturn(entityCache); + ADTaskCacheManager adTaskCacheManager = mock(ADTaskCacheManager.class); + EntityColdStarter coldStarter = mock(EntityColdStarter.class); + + action = new DeleteModelTransportAction( + threadPool, + clusterService, + transportService, + actionFilters, + nodeStateManager, + modelManager, + featureManager, + cacheProvider, + adTaskCacheManager, + coldStarter + ); + } + + public void testNormal() throws IOException, JsonPathNotFoundException { + DeleteModelRequest request = new DeleteModelRequest("123"); + assertThat(request.validate(), is(nullValue())); + + DeleteModelNodeRequest nodeRequest = new DeleteModelNodeRequest(request); + 
BytesStreamOutput nodeRequestOut = new BytesStreamOutput(); + nodeRequestOut.setVersion(Version.CURRENT); + nodeRequest.writeTo(nodeRequestOut); + StreamInput siNode = nodeRequestOut.bytes().streamInput(); + + DeleteModelNodeRequest nodeResponseRead = new DeleteModelNodeRequest(siNode); + + DeleteModelNodeResponse nodeResponse1 = action.nodeOperation(nodeResponseRead); + DeleteModelNodeResponse nodeResponse2 = action.nodeOperation(new DeleteModelNodeRequest(request)); + + DeleteModelResponse response = action.newResponse(request, Arrays.asList(nodeResponse1, nodeResponse2), Collections.emptyList()); + + assertEquals(2, response.getNodes().size()); + assertTrue(!response.hasFailures()); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + + String json = Strings.toString(builder); + Function function = (s) -> { + try { + return JsonDeserializer.getTextValue(s, CronNodeResponse.NODE_ID); + } catch (Exception e) { + Assert.fail(e.getMessage()); + } + return null; + }; + assertArrayEquals( + JsonDeserializer.getArrayValue(json, function, CronResponse.NODES_JSON_KEY), + new String[] { localNodeID, localNodeID } + ); + } + + public void testEmptyDetectorID() { + ActionRequestValidationException e = new DeleteModelRequest().validate(); + assertThat(e.validationErrors(), Matchers.hasItem(ADCommonMessages.AD_ID_MISSING_MSG)); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/DeleteTests.java b/src/test/java/org/opensearch/ad/transport/DeleteTests.java index a8692f232..619ee6bb2 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteTests.java @@ -49,10 +49,10 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.reindex.BulkByScrollResponse; diff --git a/src/test/java/org/opensearch/ad/transport/DeleteTests.java-e b/src/test/java/org/opensearch/ad/transport/DeleteTests.java-e new file mode 100644 index 000000000..5bafc6f07 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/DeleteTests.java-e @@ -0,0 +1,254 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */
+
+package org.opensearch.ad.transport;
+
+import static java.util.Collections.emptyMap;
+import static java.util.Collections.emptySet;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasItem;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.nullValue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+import java.util.function.Supplier;
+
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.opensearch.OpenSearchException;
+import org.opensearch.Version;
+import org.opensearch.action.ActionListener;
+import org.opensearch.action.ActionRequestValidationException;
+import org.opensearch.action.FailedNodeException;
+import org.opensearch.action.support.ActionFilters;
+import org.opensearch.action.support.PlainActionFuture;
+import org.opensearch.ad.common.exception.JsonPathNotFoundException;
+import org.opensearch.ad.constant.ADCommonMessages;
+import org.opensearch.ad.constant.ADCommonName;
+import org.opensearch.client.Client;
+import org.opensearch.cluster.ClusterName;
+import org.opensearch.cluster.node.DiscoveryNode;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.Strings;
+import org.opensearch.common.io.stream.BytesStreamOutput;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.common.transport.TransportAddress;
+import org.opensearch.common.unit.TimeValue;
+import org.opensearch.core.xcontent.ToXContent;
+import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.index.reindex.BulkByScrollResponse;
+import org.opensearch.tasks.Task;
+import org.opensearch.threadpool.ThreadPool;
+import org.opensearch.timeseries.AbstractTimeSeriesTest;
+import org.opensearch.timeseries.util.DiscoveryNodeFilterer;
+import org.opensearch.transport.TransportService;
+
+import test.org.opensearch.ad.util.ClusterCreation;
+import test.org.opensearch.ad.util.JsonDeserializer;
+
+public class DeleteTests extends AbstractTimeSeriesTest {
+    private DeleteModelResponse response;
+    private List<FailedNodeException> failures;
+    private List<DeleteModelNodeResponse> deleteModelResponse;
+    private String node1, node2, nodename1, nodename2;
+    private Client client;
+    private ClusterService clusterService;
+    private TransportService transportService;
+    private ThreadPool threadPool;
+    private ActionFilters actionFilters;
+    private Task task;
+
+    @Override
+    @Before
+    public void setUp() throws Exception {
+        super.setUp();
+        node1 = "node1";
+        node2 = "node2";
+        nodename1 = "nodename1";
+        nodename2 = "nodename2";
+        DiscoveryNode discoveryNode1 = new DiscoveryNode(
+            nodename1,
+            node1,
+            new TransportAddress(TransportAddress.META_ADDRESS, 9300),
+            emptyMap(),
+            emptySet(),
+            Version.CURRENT
+        );
+        DiscoveryNode discoveryNode2 = new DiscoveryNode(
+            nodename2,
+            node2,
+            new TransportAddress(TransportAddress.META_ADDRESS, 9301),
+            emptyMap(),
+            emptySet(),
+            Version.CURRENT
+        );
+        List<DiscoveryNode> discoveryNodes = new ArrayList<>(2);
+        discoveryNodes.add(discoveryNode1);
+        discoveryNodes.add(discoveryNode2);
+
+        DeleteModelNodeResponse nodeResponse1 = new DeleteModelNodeResponse(discoveryNode1);
+ DeleteModelNodeResponse nodeResponse2 = new DeleteModelNodeResponse(discoveryNode2); + + deleteModelResponse = new ArrayList<>(); + + deleteModelResponse.add(nodeResponse1); + deleteModelResponse.add(nodeResponse2); + + failures = new ArrayList<>(); + failures.add(new FailedNodeException("node3", "blah", new OpenSearchException("foo"))); + + response = new DeleteModelResponse(new ClusterName("Cluster"), deleteModelResponse, failures); + + clusterService = mock(ClusterService.class); + when(clusterService.localNode()).thenReturn(discoveryNode1); + when(clusterService.state()) + .thenReturn(ClusterCreation.state(new ClusterName("test"), discoveryNode2, discoveryNode1, discoveryNodes)); + + transportService = mock(TransportService.class); + threadPool = mock(ThreadPool.class); + actionFilters = mock(ActionFilters.class); + Settings settings = Settings.builder().put("plugins.anomaly_detection.request_timeout", TimeValue.timeValueSeconds(10)).build(); + task = mock(Task.class); + when(task.getId()).thenReturn(1000L); + client = mock(Client.class); + when(client.settings()).thenReturn(settings); + when(client.threadPool()).thenReturn(threadPool); + } + + public void testSerialzationResponse() throws IOException { + + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + DeleteModelResponse readResponse = DeleteModelAction.INSTANCE.getResponseReader().read(streamInput); + assertTrue(readResponse.hasFailures()); + + assertEquals(failures.size(), readResponse.failures().size()); + assertEquals(deleteModelResponse.size(), readResponse.getNodes().size()); + } + + public void testEmptyIDDeleteModel() { + ActionRequestValidationException e = new DeleteModelRequest("").validate(); + assertThat(e.validationErrors(), Matchers.hasItem(ADCommonMessages.AD_ID_MISSING_MSG)); + } + + public void testEmptyIDStopDetector() { + ActionRequestValidationException e = new StopDetectorRequest().validate(); + assertThat(e.validationErrors(), hasItem(ADCommonMessages.AD_ID_MISSING_MSG)); + } + + public void testValidIDStopDetector() { + ActionRequestValidationException e = new StopDetectorRequest().adID("foo").validate(); + assertThat(e, is(nullValue())); + } + + public void testSerialzationRequestDeleteModel() throws IOException { + DeleteModelRequest request = new DeleteModelRequest("123"); + BytesStreamOutput output = new BytesStreamOutput(); + request.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + DeleteModelRequest readRequest = new DeleteModelRequest(streamInput); + assertThat(request.getAdID(), equalTo(readRequest.getAdID())); + } + + public void testSerialzationRequestStopDetector() throws IOException { + StopDetectorRequest request = new StopDetectorRequest().adID("123"); + BytesStreamOutput output = new BytesStreamOutput(); + request.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + StopDetectorRequest readRequest = new StopDetectorRequest(streamInput); + assertThat(request.getAdID(), equalTo(readRequest.getAdID())); + } + + public void testJsonRequestTemplate(R request, Supplier requestSupplier) throws IOException, + JsonPathNotFoundException { + XContentBuilder builder = jsonBuilder(); + request.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String json = Strings.toString(builder); + assertEquals(JsonDeserializer.getTextValue(json, ADCommonName.ID_JSON_KEY), requestSupplier.get()); + } + + public void testJsonRequestStopDetector() throws IOException, 
JsonPathNotFoundException { + StopDetectorRequest request = new StopDetectorRequest().adID("123"); + testJsonRequestTemplate(request, request::getAdID); + } + + public void testJsonRequestDeleteModel() throws IOException, JsonPathNotFoundException { + DeleteModelRequest request = new DeleteModelRequest("123"); + testJsonRequestTemplate(request, request::getAdID); + } + + private enum DetectorExecutionMode { + DELETE_MODEL_NORMAL, + DELETE_MODEL_FAILURE + } + + @SuppressWarnings("unchecked") + public void StopDetectorResponseTemplate(DetectorExecutionMode mode) throws Exception { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + args.length >= 3 + ); + assertTrue(args[2] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[2]; + + assertTrue(listener != null); + if (mode == DetectorExecutionMode.DELETE_MODEL_FAILURE) { + listener.onFailure(new OpenSearchException("")); + } else { + listener.onResponse(response); + } + + return null; + }).when(client).execute(eq(DeleteModelAction.INSTANCE), any(), any()); + + BulkByScrollResponse deleteByQueryResponse = mock(BulkByScrollResponse.class); + when(deleteByQueryResponse.getDeleted()).thenReturn(10L); + + String detectorID = "123"; + + DiscoveryNodeFilterer nodeFilter = mock(DiscoveryNodeFilterer.class); + StopDetectorTransportAction action = new StopDetectorTransportAction(transportService, nodeFilter, actionFilters, client); + + StopDetectorRequest request = new StopDetectorRequest().adID(detectorID); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(task, request, listener); + + StopDetectorResponse response = listener.actionGet(); + assertTrue(!response.success()); + + } + + public void testNormalResponse() throws Exception { + StopDetectorResponseTemplate(DetectorExecutionMode.DELETE_MODEL_NORMAL); + } + + public void testFailureResponse() throws Exception { + StopDetectorResponseTemplate(DetectorExecutionMode.DELETE_MODEL_FAILURE); + } + +} diff --git a/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java b/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java index b76eedc80..4a1cfc718 100644 --- a/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java +++ b/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java @@ -42,9 +42,9 @@ import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.tasks.Task; import org.opensearch.timeseries.AbstractTimeSeriesTest; diff --git a/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java-e b/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java-e new file mode 100644 index 000000000..31bbc3950 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java-e @@ -0,0 +1,384 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. 
See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.opensearch.Version; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.common.exception.JsonPathNotFoundException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.EntityProfileName; +import org.opensearch.ad.model.ModelProfile; +import org.opensearch.ad.model.ModelProfileOnNode; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.transport.ConnectTransportException; +import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportInterceptor; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.TransportResponse; +import org.opensearch.transport.TransportResponseHandler; +import org.opensearch.transport.TransportService; + +import test.org.opensearch.ad.util.FakeNode; +import test.org.opensearch.ad.util.JsonDeserializer; + +public class EntityProfileTests extends AbstractTimeSeriesTest { + private String detectorId = "yecrdnUBqurvo9uKU_d8"; + private String entityValue = "app_0"; + private String nodeId = "abc"; + private Set state; + private Set all; + private Set model; + private HashRing hashRing; + private ActionFilters actionFilters; + private TransportService transportService; + private Settings settings; + private ClusterService clusterService; + private CacheProvider cacheProvider; + private EntityProfileTransportAction action; + private Task task; + private PlainActionFuture future; + private TransportAddress transportAddress1; + private long updates; + private EntityProfileRequest request; + private String modelId; + private long lastActiveTimestamp = 1603989830158L; + private long modelSize = 712480L; + private boolean isActive = Boolean.TRUE; + private TransportInterceptor normalTransportInterceptor, failureTransportInterceptor; + private String categoryName = "field"; + private Entity entity; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(EntityProfileTests.class.getSimpleName()); + } + + @AfterClass + public 
static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @Override + public void setUp() throws Exception { + super.setUp(); + state = new HashSet(); + state.add(EntityProfileName.STATE); + + all = new HashSet(); + all.add(EntityProfileName.INIT_PROGRESS); + all.add(EntityProfileName.ENTITY_INFO); + all.add(EntityProfileName.MODELS); + + model = new HashSet(); + model.add(EntityProfileName.MODELS); + + hashRing = mock(HashRing.class); + actionFilters = mock(ActionFilters.class); + transportService = new TransportService( + Settings.EMPTY, + mock(Transport.class), + null, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + x -> null, + null, + Collections.emptySet() + ); + settings = Settings.EMPTY; + + modelId = "yecrdnUBqurvo9uKU_d8_entity_app_0"; + + clusterService = mock(ClusterService.class); + + cacheProvider = mock(CacheProvider.class); + EntityCache cache = mock(EntityCache.class); + updates = 1L; + when(cache.getTotalUpdates(anyString(), anyString())).thenReturn(updates); + when(cache.isActive(anyString(), anyString())).thenReturn(isActive); + when(cache.getLastActiveMs(anyString(), anyString())).thenReturn(lastActiveTimestamp); + Map modelSizeMap = new HashMap<>(); + modelSizeMap.put(modelId, modelSize); + when(cache.getModelSize(anyString())).thenReturn(modelSizeMap); + when(cacheProvider.get()).thenReturn(cache); + + action = new EntityProfileTransportAction(actionFilters, transportService, settings, hashRing, clusterService, cacheProvider); + + future = new PlainActionFuture<>(); + transportAddress1 = new TransportAddress(new InetSocketAddress(InetAddress.getByName("1.2.3.4"), 9300)); + + entity = Entity.createSingleAttributeEntity(categoryName, entityValue); + + request = new EntityProfileRequest(detectorId, entity, state); + + normalTransportInterceptor = new TransportInterceptor() { + @Override + public AsyncSender interceptSender(AsyncSender sender) { + return new AsyncSender() { + @Override + public void sendRequest( + Transport.Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + if (EntityProfileAction.NAME.equals(action)) { + sender.sendRequest(connection, action, request, options, entityProfileHandler(handler)); + } else { + sender.sendRequest(connection, action, request, options, handler); + } + } + }; + } + }; + + failureTransportInterceptor = new TransportInterceptor() { + @Override + public AsyncSender interceptSender(AsyncSender sender) { + return new AsyncSender() { + @Override + public void sendRequest( + Transport.Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + if (EntityProfileAction.NAME.equals(action)) { + sender.sendRequest(connection, action, request, options, entityFailureProfileandler(handler)); + } else { + sender.sendRequest(connection, action, request, options, handler); + } + } + }; + } + }; + } + + private TransportResponseHandler entityProfileHandler(TransportResponseHandler handler) { + return new TransportResponseHandler() { + @Override + public T read(StreamInput in) throws IOException { + return handler.read(in); + } + + @Override + @SuppressWarnings("unchecked") + public void handleResponse(T response) { + handler.handleResponse((T) new EntityProfileResponse.Builder().setTotalUpdates(updates).build()); + } + + @Override + public void handleException(TransportException exp) { + handler.handleException(exp); + } + + @Override + public String 
executor() { + return handler.executor(); + } + }; + } + + private TransportResponseHandler entityFailureProfileandler(TransportResponseHandler handler) { + return new TransportResponseHandler() { + @Override + public T read(StreamInput in) throws IOException { + return handler.read(in); + } + + @Override + public void handleResponse(T response) { + handler + .handleException( + new ConnectTransportException( + new DiscoveryNode(nodeId, transportAddress1, Version.CURRENT.minimumCompatibilityVersion()), + EntityProfileAction.NAME + ) + ); + } + + @Override + public void handleException(TransportException exp) { + handler.handleException(exp); + } + + @Override + public String executor() { + return handler.executor(); + } + }; + } + + private void registerHandler(FakeNode node) { + new EntityProfileTransportAction( + new ActionFilters(Collections.emptySet()), + node.transportService, + Settings.EMPTY, + hashRing, + node.clusterService, + cacheProvider + ); + } + + public void testInvalidRequest() { + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(anyString())).thenReturn(Optional.empty()); + action.doExecute(task, request, future); + + assertException(future, TimeSeriesException.class, EntityProfileTransportAction.NO_NODE_FOUND_MSG); + } + + public void testLocalNodeHit() { + DiscoveryNode localNode = new DiscoveryNode(nodeId, transportAddress1, Version.CURRENT.minimumCompatibilityVersion()); + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(anyString())).thenReturn(Optional.of(localNode)); + when(clusterService.localNode()).thenReturn(localNode); + + action.doExecute(task, request, future); + EntityProfileResponse response = future.actionGet(20_000); + assertEquals(updates, response.getTotalUpdates()); + } + + public void testAllHit() { + DiscoveryNode localNode = new DiscoveryNode(nodeId, transportAddress1, Version.CURRENT.minimumCompatibilityVersion()); + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(anyString())).thenReturn(Optional.of(localNode)); + when(clusterService.localNode()).thenReturn(localNode); + + request = new EntityProfileRequest(detectorId, entity, all); + action.doExecute(task, request, future); + + EntityProfileResponse expectedResponse = new EntityProfileResponse(isActive, lastActiveTimestamp, updates, null); + EntityProfileResponse response = future.actionGet(20_000); + assertEquals(expectedResponse, response); + } + + public void testGetRemoteUpdateResponse() { + setupTestNodes(normalTransportInterceptor); + try { + TransportService realTransportService = testNodes[0].transportService; + clusterService = testNodes[0].clusterService; + + action = new EntityProfileTransportAction( + actionFilters, + realTransportService, + settings, + hashRing, + clusterService, + cacheProvider + ); + + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + .thenReturn(Optional.of(testNodes[1].discoveryNode())); + registerHandler(testNodes[1]); + + action.doExecute(null, request, future); + + EntityProfileResponse expectedResponse = new EntityProfileResponse(null, -1L, updates, null); + + EntityProfileResponse response = future.actionGet(10_000); + assertEquals(expectedResponse, response); + } finally { + tearDownTestNodes(); + } + } + + public void testGetRemoteFailureResponse() { + setupTestNodes(failureTransportInterceptor); + try { + TransportService realTransportService = testNodes[0].transportService; + clusterService = testNodes[0].clusterService; + + action = new EntityProfileTransportAction( + 
actionFilters, + realTransportService, + settings, + hashRing, + clusterService, + cacheProvider + ); + + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + .thenReturn(Optional.of(testNodes[1].discoveryNode())); + registerHandler(testNodes[1]); + + action.doExecute(null, request, future); + + expectThrows(ConnectTransportException.class, () -> future.actionGet()); + } finally { + tearDownTestNodes(); + } + } + + public void testResponseToXContent() throws IOException, JsonPathNotFoundException { + long lastActiveTimestamp = 10L; + EntityProfileResponse.Builder builder = new EntityProfileResponse.Builder(); + builder.setLastActiveMs(lastActiveTimestamp).build(); + builder.setModelProfile(new ModelProfileOnNode(nodeId, new ModelProfile(modelId, entity, modelSize))); + EntityProfileResponse response = builder.build(); + String json = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + assertEquals(lastActiveTimestamp, JsonDeserializer.getLongValue(json, EntityProfileResponse.LAST_ACTIVE_TS)); + assertEquals(modelSize, JsonDeserializer.getChildNode(json, ADCommonName.MODEL, CommonName.MODEL_SIZE_IN_BYTES).getAsLong()); + } + + public void testResponseHashCodeEquals() { + EntityProfileResponse.Builder builder = new EntityProfileResponse.Builder(); + builder.setLastActiveMs(lastActiveTimestamp).build(); + ModelProfileOnNode model = new ModelProfileOnNode(nodeId, new ModelProfile(modelId, entity, modelSize)); + builder.setModelProfile(model); + EntityProfileResponse response = builder.build(); + + HashSet set = new HashSet<>(); + assertTrue(false == set.contains(response)); + set.add(response); + assertTrue(set.contains(response)); + } + + public void testEntityProfileName() { + assertEquals("state", EntityProfileName.getName(ADCommonName.STATE).getName()); + assertEquals("models", EntityProfileName.getName(ADCommonName.MODELS).getName()); + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> EntityProfileName.getName("abc")); + assertEquals(exception.getMessage(), ADCommonMessages.UNSUPPORTED_PROFILE_TYPE); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java-e new file mode 100644 index 000000000..30177220b --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java-e @@ -0,0 +1,416 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.startsWith; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; + +import java.io.IOException; +import java.time.Clock; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.apache.commons.lang3.tuple.Pair; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.ad.AnomalyDetectorJobRunnerTests; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.common.exception.JsonPathNotFoundException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.CheckpointDao; +import org.opensearch.ad.ml.EntityColdStarter; +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.ratelimit.CheckpointReadWorker; +import org.opensearch.ad.ratelimit.ColdEntityWorker; +import org.opensearch.ad.ratelimit.EntityColdStartWorker; +import org.opensearch.ad.ratelimit.ResultWriteWorker; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.stats.ADStat; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.stats.suppliers.CounterSupplier; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Strings; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.stats.StatNames; +import 
org.opensearch.transport.TransportService; + +import test.org.opensearch.ad.util.JsonDeserializer; +import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; + +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; + +public class EntityResultTransportActionTests extends AbstractTimeSeriesTest { + EntityResultTransportAction entityResult; + ActionFilters actionFilters; + TransportService transportService; + ModelManager manager; + ADCircuitBreakerService adCircuitBreakerService; + CheckpointDao checkpointDao; + CacheProvider provider; + EntityCache entityCache; + NodeStateManager stateManager; + Settings settings; + Clock clock; + EntityResultRequest request; + String detectorId; + long timeoutMs; + AnomalyDetector detector; + String cacheMissEntity; + String cacheHitEntity; + Entity cacheHitEntityObj; + Entity cacheMissEntityObj; + long start; + long end; + Map entities; + double[] cacheMissData; + double[] cacheHitData; + String tooLongEntity; + double[] tooLongData; + ResultWriteWorker resultWriteQueue; + CheckpointReadWorker checkpointReadQueue; + int minSamples; + Instant now; + EntityColdStarter coldStarter; + ColdEntityWorker coldEntityQueue; + EntityColdStartWorker entityColdStartQueue; + ADIndexManagement indexUtil; + ClusterService clusterService; + ADStats adStats; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyDetectorJobRunnerTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @SuppressWarnings("unchecked") + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + actionFilters = mock(ActionFilters.class); + transportService = mock(TransportService.class); + + adCircuitBreakerService = mock(ADCircuitBreakerService.class); + when(adCircuitBreakerService.isOpen()).thenReturn(false); + + checkpointDao = mock(CheckpointDao.class); + + detectorId = "123"; + entities = new HashMap<>(); + + start = 10L; + end = 20L; + request = new EntityResultRequest(detectorId, entities, start, end); + + clock = mock(Clock.class); + now = Instant.now(); + when(clock.instant()).thenReturn(now); + + settings = Settings + .builder() + .put(AnomalyDetectorSettings.COOLDOWN_MINUTES.getKey(), TimeValue.timeValueMinutes(5)) + .put(AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ.getKey(), TimeValue.timeValueHours(12)) + .build(); + + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + manager = new ModelManager( + null, + clock, + 0, + 0, + 0, + 0, + 0, + 0, + null, + AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, + mock(EntityColdStarter.class), + null, + null, + settings, + clusterService + ); + + provider = mock(CacheProvider.class); + entityCache = mock(EntityCache.class); + when(provider.get()).thenReturn(entityCache); + + String field = "a"; + detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); + stateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + cacheMissEntity = "0.0.0.1"; + 
cacheMissData = new double[] { 0.1 }; + cacheHitEntity = "0.0.0.2"; + cacheHitData = new double[] { 0.2 }; + cacheMissEntityObj = Entity.createSingleAttributeEntity(detector.getCategoryFields().get(0), cacheMissEntity); + entities.put(cacheMissEntityObj, cacheMissData); + cacheHitEntityObj = Entity.createSingleAttributeEntity(detector.getCategoryFields().get(0), cacheHitEntity); + entities.put(cacheHitEntityObj, cacheHitData); + tooLongEntity = randomAlphaOfLength(AnomalyDetectorSettings.MAX_ENTITY_LENGTH + 1); + tooLongData = new double[] { 0.3 }; + entities.put(Entity.createSingleAttributeEntity(detector.getCategoryFields().get(0), tooLongEntity), tooLongData); + + ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + when(entityCache.get(eq(cacheMissEntityObj.getModelId(detectorId).get()), any())).thenReturn(null); + when(entityCache.get(eq(cacheHitEntityObj.getModelId(detectorId).get()), any())).thenReturn(state); + + List coldEntities = new ArrayList<>(); + coldEntities.add(cacheMissEntityObj); + when(entityCache.selectUpdateCandidate(any(), anyString(), any())).thenReturn(Pair.of(new ArrayList<>(), coldEntities)); + + indexUtil = mock(ADIndexManagement.class); + when(indexUtil.getSchemaVersion(any())).thenReturn(CommonValue.NO_SCHEMA_VERSION); + + resultWriteQueue = mock(ResultWriteWorker.class); + checkpointReadQueue = mock(CheckpointReadWorker.class); + + minSamples = 1; + + coldStarter = mock(EntityColdStarter.class); + + doAnswer(invocation -> { + ModelState modelState = invocation.getArgument(0); + modelState.getModel().clear(); + return null; + }).when(coldStarter).trainModelFromExistingSamples(any(), anyInt()); + + coldEntityQueue = mock(ColdEntityWorker.class); + entityColdStartQueue = mock(EntityColdStartWorker.class); + + Map> statsMap = new HashMap>() { + { + put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + } + }; + + adStats = new ADStats(statsMap); + + entityResult = new EntityResultTransportAction( + actionFilters, + transportService, + manager, + adCircuitBreakerService, + provider, + stateManager, + indexUtil, + resultWriteQueue, + checkpointReadQueue, + coldEntityQueue, + threadPool, + entityColdStartQueue, + adStats + ); + + // timeout in 60 seconds + timeoutMs = 60000L; + } + + public void testCircuitBreakerOpen() { + when(adCircuitBreakerService.isOpen()).thenReturn(true); + PlainActionFuture future = PlainActionFuture.newFuture(); + + entityResult.doExecute(null, request, future); + + expectThrows(LimitExceededException.class, () -> future.actionGet(timeoutMs)); + } + + public void testNormal() { + PlainActionFuture future = PlainActionFuture.newFuture(); + + entityResult.doExecute(null, request, future); + + future.actionGet(timeoutMs); + + verify(resultWriteQueue, times(1)).put(any()); + } + + // test get detector failure + @SuppressWarnings("unchecked") + public void testFailtoGetDetector() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.empty()); + return null; + }).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + + PlainActionFuture future = PlainActionFuture.newFuture(); + + entityResult.doExecute(null, request, future); + + expectThrows(EndRunException.class, () -> future.actionGet(timeoutMs)); + } + + // test rcf score is 0 + public void testNoResultsToSave() { + ModelState state = MLUtil.randomModelState(new 
RandomModelStateConfig.Builder().fullModel(false).build()); + when(entityCache.get(eq(cacheHitEntityObj.getModelId(detectorId).get()), any())).thenReturn(state); + + PlainActionFuture future = PlainActionFuture.newFuture(); + + entityResult.doExecute(null, request, future); + + future.actionGet(timeoutMs); + + verify(resultWriteQueue, never()).put(any()); + } + + public void testValidRequest() { + ActionRequestValidationException e = request.validate(); + assertThat(e, equalTo(null)); + } + + public void testEmptyId() { + request = new EntityResultRequest("", entities, start, end); + ActionRequestValidationException e = request.validate(); + assertThat(e.validationErrors(), hasItem(ADCommonMessages.AD_ID_MISSING_MSG)); + } + + public void testReverseTime() { + request = new EntityResultRequest(detectorId, entities, end, start); + ActionRequestValidationException e = request.validate(); + assertThat(e.validationErrors(), hasItem(startsWith(CommonMessages.INVALID_TIMESTAMP_ERR_MSG))); + } + + public void testNegativeTime() { + request = new EntityResultRequest(detectorId, entities, start, -end); + ActionRequestValidationException e = request.validate(); + assertThat(e.validationErrors(), hasItem(startsWith(CommonMessages.INVALID_TIMESTAMP_ERR_MSG))); + } + + public void testJsonResponse() throws IOException, JsonPathNotFoundException { + XContentBuilder builder = jsonBuilder(); + request.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String json = Strings.toString(builder); + assertEquals(JsonDeserializer.getTextValue(json, ADCommonName.ID_JSON_KEY), detectorId); + assertEquals(JsonDeserializer.getLongValue(json, CommonName.START_JSON_KEY), start); + assertEquals(JsonDeserializer.getLongValue(json, CommonName.END_JSON_KEY), end); + JsonArray array = JsonDeserializer.getArrayValue(json, CommonName.ENTITIES_JSON_KEY); + assertEquals(3, array.size()); + for (int i = 0; i < 3; i++) { + JsonElement element = array.get(i); + JsonElement entity = JsonDeserializer.getChildNode(element, CommonName.ENTITY_KEY); + JsonArray entityArray = entity.getAsJsonArray(); + assertEquals(1, entityArray.size()); + + JsonElement attribute = entityArray.get(0); + String entityValue = JsonDeserializer.getChildNode(attribute, Entity.ATTRIBUTE_VALUE_FIELD).getAsString(); + + double value = JsonDeserializer.getChildNode(element, CommonName.VALUE_JSON_KEY).getAsJsonArray().get(0).getAsDouble(); + + if (entityValue.equals(cacheMissEntity)) { + assertEquals(0, Double.compare(cacheMissData[0], value)); + } else if (entityValue.equals(cacheHitEntity)) { + assertEquals(0, Double.compare(cacheHitData[0], value)); + } else { + assertEquals(0, Double.compare(tooLongData[0], value)); + } + } + } + + public void testFailToScore() { + ModelManager spyModelManager = spy(manager); + doThrow(new IllegalArgumentException()).when(spyModelManager).getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt()); + entityResult = new EntityResultTransportAction( + actionFilters, + transportService, + spyModelManager, + adCircuitBreakerService, + provider, + stateManager, + indexUtil, + resultWriteQueue, + checkpointReadQueue, + coldEntityQueue, + threadPool, + entityColdStartQueue, + adStats + ); + + PlainActionFuture future = PlainActionFuture.newFuture(); + + entityResult.doExecute(null, request, future); + + future.actionGet(timeoutMs); + + verify(resultWriteQueue, never()).put(any()); + verify(entityCache, times(1)).removeEntityModel(anyString(), anyString()); + verify(entityColdStartQueue, times(1)).put(any()); + Object val = 
adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).getValue(); + assertEquals(1L, ((Long) val).longValue()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java b/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java index 94cbdee06..633a9a4fe 100644 --- a/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java +++ b/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java @@ -25,18 +25,18 @@ import org.opensearch.Version; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.mock.transport.MockADTaskAction_1_0; import org.opensearch.ad.mock.transport.MockForwardADTaskRequest_1_0; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.plugins.Plugin; import org.opensearch.test.InternalSettingsPlugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.VersionException; import org.opensearch.timeseries.settings.TimeSeriesSettings; @@ -46,7 +46,7 @@ public class ForwardADTaskRequestTests extends OpenSearchSingleNodeTestCase { @Override protected Collection> getPlugins() { - return pluginList(InternalSettingsPlugin.class, AnomalyDetectorPlugin.class); + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); } @Override diff --git a/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java-e b/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java-e new file mode 100644 index 000000000..633a9a4fe --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ForwardADTaskRequestTests.java-e @@ -0,0 +1,152 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.model.ADTaskAction.CLEAN_CACHE; +import static org.opensearch.ad.model.ADTaskAction.CLEAN_STALE_RUNNING_ENTITIES; +import static org.opensearch.ad.model.ADTaskAction.START; +import static org.opensearch.timeseries.TestHelpers.randomIntervalTimeConfiguration; +import static org.opensearch.timeseries.TestHelpers.randomQuery; +import static org.opensearch.timeseries.TestHelpers.randomUser; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collection; +import java.util.Locale; + +import org.opensearch.Version; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.mock.transport.MockADTaskAction_1_0; +import org.opensearch.ad.mock.transport.MockForwardADTaskRequest_1_0; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.InternalSettingsPlugin; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.common.exception.VersionException; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.google.common.collect.ImmutableList; + +public class ForwardADTaskRequestTests extends OpenSearchSingleNodeTestCase { + + @Override + protected Collection> getPlugins() { + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); + } + + @Override + protected NamedWriteableRegistry writableRegistry() { + return getInstanceFromNode(NamedWriteableRegistry.class); + } + + public void testNullVersion() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableList.of()); + expectThrows(VersionException.class, () -> new ForwardADTaskRequest(detector, null, null, null, null, null)); + } + + public void testNullDetectorIdAndTaskAction() throws IOException { + AnomalyDetector detector = new AnomalyDetector( + null, + randomLong(), + randomAlphaOfLength(20), + randomAlphaOfLength(30), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)), + ImmutableList.of(), + randomQuery(), + randomIntervalTimeConfiguration(), + randomIntervalTimeConfiguration(), + randomIntBetween(1, TimeSeriesSettings.MAX_SHINGLE_SIZE), + null, + randomInt(), + Instant.now(), + null, + randomUser(), + null, + TestHelpers.randomImputationOption() + ); + ForwardADTaskRequest request = new ForwardADTaskRequest(detector, null, null, null, null, Version.V_2_1_0); + ActionRequestValidationException validate = request.validate(); + assertEquals("Validation Failed: 1: AD ID is missing;2: AD task action is missing;", validate.getMessage()); + } + + public void testEmptyStaleEntities() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(); + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, CLEAN_STALE_RUNNING_ENTITIES, null); + ActionRequestValidationException validate = request.validate(); + assertEquals("Validation Failed: 1: Empty stale running entities;", validate.getMessage()); + } + + public void testSerializeRequest() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(); + ForwardADTaskRequest request = new 
ForwardADTaskRequest(adTask, CLEAN_STALE_RUNNING_ENTITIES, null); + BytesStreamOutput output = new BytesStreamOutput(); + request.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + ForwardADTaskRequest parsedInput = new ForwardADTaskRequest(input); + assertEquals(request, parsedInput); + } + + public void testParseRequestFromOldNodeWithNewCode() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(); + MockForwardADTaskRequest_1_0 oldRequest = new MockForwardADTaskRequest_1_0( + adTask.getDetector(), + adTask.getUser(), + MockADTaskAction_1_0.START + ); + BytesStreamOutput output = new BytesStreamOutput(); + oldRequest.writeTo(output); + + // Parse old forward AD task request of 1.0, will reject it directly, + // so if old node is coordinating node, it can't use new node as worker node to run task. + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + expectThrows(VersionException.class, () -> new ForwardADTaskRequest(input)); + } + + public void testParseRequestFromNewNodeWithOldCode_StartAction() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(); + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, START, null); + BytesStreamOutput output = new BytesStreamOutput(); + request.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + MockForwardADTaskRequest_1_0 parsedInput = new MockForwardADTaskRequest_1_0(input); + // START action should be parsed as START action on old node + // If coordinating node is new node, it will just use new node as worker node to run task. + // So it's impossible that new node will send START action to old node. Add this test case + // just to show the request parsing logic works. + assertEquals(MockADTaskAction_1_0.START, parsedInput.getAdTaskAction()); + assertEquals(request.getDetector(), parsedInput.getDetector()); + } + + public void testParseRequestFromNewNodeWithOldCode_CleanCacheAction() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(); + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, CLEAN_CACHE, null); + BytesStreamOutput output = new BytesStreamOutput(); + request.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + MockForwardADTaskRequest_1_0 parsedInput = new MockForwardADTaskRequest_1_0(input); + // CLEAN_CACHE action should be parsed as STOP action on old node + // In old version on or before AD1.0, worker node will send STOP action to clean cache + // on coordinating node when task done on worker node. + // In mixed cluster, new node will reject START action if it's from old node. + // So no new node will run as worker node for old coordinating node. + // Add this test case just to show the request task action parsing logic works. 
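The comment block above explains why a CLEAN_CACHE request written by a new node must deserialize as STOP on a pre-1.1 node. Every wire-compatibility test in this file follows the same round trip: serialize with one request class, then parse the same bytes with the other version's class. Below is a minimal, self-contained sketch of that pattern; NewSketchRequest, OldSketchRequest, and WireRoundTripSketch are hypothetical stand-ins (the real tests use ForwardADTaskRequest and MockForwardADTaskRequest_1_0), and the stream imports assume the post-refactor org.opensearch.core packages used elsewhere in this diff.

import java.io.IOException;

import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

// Hypothetical "new" wire format: writes a single action string.
class NewSketchRequest {
    final String action;

    NewSketchRequest(String action) {
        this.action = action;
    }

    void writeTo(StreamOutput out) throws IOException {
        out.writeString(action);
    }
}

// Hypothetical "old" wire format: reads back whatever the new writer produced.
class OldSketchRequest {
    final String action;

    OldSketchRequest(StreamInput in) throws IOException {
        this.action = in.readString();
    }
}

class WireRoundTripSketch {
    // Serialize with the new class, then parse the same bytes with the old class,
    // mirroring how these tests check mixed-cluster compatibility.
    static OldSketchRequest parseWithOldFormat(NewSketchRequest request) throws IOException {
        try (BytesStreamOutput out = new BytesStreamOutput()) {
            request.writeTo(out);
            try (StreamInput in = out.bytes().streamInput()) {
                return new OldSketchRequest(in);
            }
        }
    }
}

The real test then asserts on the parsed action, comparing it against the value the old node is expected to see, as the diff continues below.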
+ assertEquals(MockADTaskAction_1_0.STOP, parsedInput.getAdTaskAction()); + assertEquals(request.getDetector(), parsedInput.getDetector()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ForwardADTaskTests.java b/src/test/java/org/opensearch/ad/transport/ForwardADTaskTests.java index 29b706dcc..982cf9262 100644 --- a/src/test/java/org/opensearch/ad/transport/ForwardADTaskTests.java +++ b/src/test/java/org/opensearch/ad/transport/ForwardADTaskTests.java @@ -18,16 +18,16 @@ import org.junit.Before; import org.opensearch.Version; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.model.ADTaskAction; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.plugins.Plugin; import org.opensearch.test.InternalSettingsPlugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import com.google.common.collect.ImmutableMap; @@ -43,7 +43,7 @@ public void setUp() throws Exception { @Override protected Collection> getPlugins() { - return pluginList(InternalSettingsPlugin.class, AnomalyDetectorPlugin.class); + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); } @Override diff --git a/src/test/java/org/opensearch/ad/transport/ForwardADTaskTests.java-e b/src/test/java/org/opensearch/ad/transport/ForwardADTaskTests.java-e new file mode 100644 index 000000000..982cf9262 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ForwardADTaskTests.java-e @@ -0,0 +1,103 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collection; + +import org.junit.Before; +import org.opensearch.Version; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.model.ADTaskAction; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.InternalSettingsPlugin; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +import com.google.common.collect.ImmutableMap; + +public class ForwardADTaskTests extends OpenSearchSingleNodeTestCase { + private Version testVersion; + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + testVersion = Version.fromString("1.1.0"); + } + + @Override + protected Collection> getPlugins() { + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); + } + + @Override + protected NamedWriteableRegistry writableRegistry() { + return getInstanceFromNode(NamedWriteableRegistry.class); + } + + public void testForwardADTaskRequest() throws IOException { + ForwardADTaskRequest request = new ForwardADTaskRequest( + TestHelpers.randomAnomalyDetector(ImmutableMap.of(), Instant.now()), + TestHelpers.randomDetectionDateRange(), + TestHelpers.randomUser(), + ADTaskAction.START, + randomInt(), + testVersion + ); + testForwardADTaskRequest(request); + } + + public void testForwardADTaskRequestWithoutUser() throws IOException { + ForwardADTaskRequest request = new ForwardADTaskRequest( + TestHelpers.randomAnomalyDetector(ImmutableMap.of(), Instant.now()), + TestHelpers.randomDetectionDateRange(), + null, + ADTaskAction.START, + randomInt(), + testVersion + ); + testForwardADTaskRequest(request); + } + + public void testInvalidForwardADTaskRequest() { + ForwardADTaskRequest request = new ForwardADTaskRequest( + null, + TestHelpers.randomDetectionDateRange(), + TestHelpers.randomUser(), + ADTaskAction.START, + randomInt(), + testVersion + ); + + ActionRequestValidationException exception = request.validate(); + assertTrue(exception.getMessage().contains(ADCommonMessages.DETECTOR_MISSING)); + } + + private void testForwardADTaskRequest(ForwardADTaskRequest request) throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + request.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + ForwardADTaskRequest parsedRequest = new ForwardADTaskRequest(input); + if (request.getUser() != null) { + assertTrue(request.getUser().equals(parsedRequest.getUser())); + } else { + assertNull(parsedRequest.getUser()); + } + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ForwardADTaskTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/ForwardADTaskTransportActionTests.java-e new file mode 100644 index 000000000..f2da82c36 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ForwardADTaskTransportActionTests.java-e @@ -0,0 +1,261 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * 
compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.model.ADTaskAction.CANCEL; +import static org.opensearch.ad.model.ADTaskAction.CHECK_AVAILABLE_TASK_SLOTS; +import static org.opensearch.ad.model.ADTaskAction.CLEAN_STALE_RUNNING_ENTITIES; +import static org.opensearch.ad.model.ADTaskAction.NEXT_ENTITY; +import static org.opensearch.ad.model.ADTaskAction.PUSH_BACK_ENTITY; +import static org.opensearch.ad.model.ADTaskAction.SCALE_ENTITY_TASK_SLOTS; + +import java.io.IOException; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.ADUnitTestCase; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableList; + +public class ForwardADTaskTransportActionTests extends ADUnitTestCase { + private ActionFilters actionFilters; + private TransportService transportService; + private ADTaskManager adTaskManager; + private ADTaskCacheManager adTaskCacheManager; + private FeatureManager featureManager; + private NodeStateManager stateManager; + private ForwardADTaskTransportAction forwardADTaskTransportAction; + private Task task; + private ActionListener listener; + + @SuppressWarnings("unchecked") + @Override + public void setUp() throws Exception { + super.setUp(); + actionFilters = mock(ActionFilters.class); + transportService = mock(TransportService.class); + adTaskManager = mock(ADTaskManager.class); + adTaskCacheManager = mock(ADTaskCacheManager.class); + featureManager = mock(FeatureManager.class); + stateManager = mock(NodeStateManager.class); + forwardADTaskTransportAction = new ForwardADTaskTransportAction( + actionFilters, + transportService, + adTaskManager, + adTaskCacheManager, + featureManager, + stateManager + ); + + task = mock(Task.class); + listener = mock(ActionListener.class); + } + + public void testCheckAvailableTaskSlots() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(); + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, CHECK_AVAILABLE_TASK_SLOTS); + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(adTaskManager, times(1)).checkTaskSlots(any(), any(), any(), any(), any(), any(), any()); + } + + public void testNextEntityTaskForSingleEntityDetector() throws IOException { + when(adTaskCacheManager.hasEntity(anyString())).thenReturn(false); + + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_SINGLE_ENTITY); + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, NEXT_ENTITY); + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(listener, 
times(1)).onFailure(any()); + } + + public void testNextEntityTaskWithNoPendingEntity() throws IOException { + when(adTaskCacheManager.hasEntity(anyString())).thenReturn(false); + + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_DETECTOR); + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, NEXT_ENTITY); + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(adTaskCacheManager, times(1)).setDetectorTaskSlots(anyString(), eq(0)); + verify(adTaskManager, times(1)).setHCDetectorTaskDone(any(), any(), any()); + } + + public void testNextEntityTaskWithPendingEntity() throws IOException { + when(adTaskCacheManager.hasEntity(anyString())).thenReturn(true); + + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_DETECTOR); + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, NEXT_ENTITY); + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(adTaskManager, times(1)).runNextEntityForHCADHistorical(any(), any(), any()); + verify(adTaskManager, times(1)).updateADHCDetectorTask(any(), any(), any()); + } + + public void testPushBackEntityForSingleEntityDetector() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_SINGLE_ENTITY); + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, PUSH_BACK_ENTITY); + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(listener, times(1)).onFailure(any()); + } + + public void testPushBackEntityForNonRetryableExceptionAndNoPendingEntity() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_ENTITY); + when(adTaskManager.convertEntityToString(any())).thenReturn(randomAlphaOfLength(5)); + when(adTaskManager.isRetryableError(any())).thenReturn(false); + when(adTaskCacheManager.hasEntity(anyString())).thenReturn(false); + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, PUSH_BACK_ENTITY); + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(adTaskCacheManager, times(1)).removeEntity(anyString(), anyString()); + verify(adTaskCacheManager, times(1)).setDetectorTaskSlots(anyString(), eq(0)); + verify(adTaskManager, times(1)).setHCDetectorTaskDone(any(), any(), any()); + } + + public void testPushBackEntityForNonRetryableExceptionAndPendingEntity() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_ENTITY); + when(adTaskManager.convertEntityToString(any())).thenReturn(randomAlphaOfLength(5)); + when(adTaskManager.isRetryableError(any())).thenReturn(false); + when(adTaskCacheManager.hasEntity(anyString())).thenReturn(true); + when(adTaskCacheManager.scaleDownHCDetectorTaskSlots(anyString(), anyInt())).thenReturn(randomIntBetween(2, 10)).thenReturn(1); + + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, PUSH_BACK_ENTITY); + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(adTaskCacheManager, times(1)).removeEntity(anyString(), anyString()); + verify(adTaskManager, times(0)).runNextEntityForHCADHistorical(any(), any(), any()); + + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(adTaskCacheManager, times(2)).removeEntity(anyString(), anyString()); + verify(adTaskManager, times(1)).runNextEntityForHCADHistorical(any(), any(), any()); + } + + public void testPushBackEntityForRetryableExceptionAndNoPendingEntity() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_ENTITY); + 
when(adTaskManager.convertEntityToString(any())).thenReturn(randomAlphaOfLength(5)); + when(adTaskManager.isRetryableError(any())).thenReturn(true); + when(adTaskCacheManager.exceedRetryLimit(any(), any())).thenReturn(false).thenReturn(true); + when(adTaskCacheManager.hasEntity(anyString())).thenReturn(false); + + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, PUSH_BACK_ENTITY); + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(adTaskCacheManager, times(1)).pushBackEntity(anyString(), anyString(), anyString()); + verify(adTaskCacheManager, times(1)).setDetectorTaskSlots(anyString(), eq(0)); + verify(adTaskManager, times(1)).setHCDetectorTaskDone(any(), any(), any()); + + forwardADTaskTransportAction.doExecute(task, request, listener); + // will not push back entity task if exceed retry limit + verify(adTaskCacheManager, times(1)).pushBackEntity(anyString(), anyString(), anyString()); + verify(adTaskCacheManager, times(2)).setDetectorTaskSlots(anyString(), eq(0)); + verify(adTaskManager, times(2)).setHCDetectorTaskDone(any(), any(), any()); + } + + public void testPushBackEntityForRetryableExceptionAndPendingEntity() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_ENTITY); + when(adTaskManager.convertEntityToString(any())).thenReturn(randomAlphaOfLength(5)); + when(adTaskManager.isRetryableError(any())).thenReturn(true); + when(adTaskCacheManager.hasEntity(anyString())).thenReturn(true); + when(adTaskCacheManager.scaleDownHCDetectorTaskSlots(anyString(), anyInt())).thenReturn(randomIntBetween(2, 10)).thenReturn(1); + + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, PUSH_BACK_ENTITY); + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(adTaskCacheManager, times(1)).pushBackEntity(anyString(), anyString(), anyString()); + verify(adTaskCacheManager, times(0)).removeEntity(anyString(), anyString()); + verify(adTaskManager, times(0)).runNextEntityForHCADHistorical(any(), any(), any()); + + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(adTaskCacheManager, times(2)).pushBackEntity(anyString(), anyString(), anyString()); + verify(adTaskCacheManager, times(0)).removeEntity(anyString(), anyString()); + verify(adTaskManager, times(1)).runNextEntityForHCADHistorical(any(), any(), any()); + } + + public void testScaleEntityTaskSlotsWithNoAvailableSlots() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_ENTITY); + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, null, SCALE_ENTITY_TASK_SLOTS); + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(adTaskCacheManager, never()).scaleUpDetectorTaskSlots(anyString(), anyInt()); + verify(listener, times(1)).onResponse(any()); + + request = new ForwardADTaskRequest(adTask, 0, SCALE_ENTITY_TASK_SLOTS); + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(adTaskCacheManager, never()).scaleUpDetectorTaskSlots(anyString(), anyInt()); + verify(listener, times(2)).onResponse(any()); + } + + public void testScaleEntityTaskSlotsWithAvailableSlots() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_ENTITY); + when(adTaskManager.detectorTaskSlotScaleDelta(anyString())).thenReturn(randomIntBetween(10, 20)).thenReturn(-1); + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, 1, SCALE_ENTITY_TASK_SLOTS); + + forwardADTaskTransportAction.doExecute(task, request, listener); + 
verify(adTaskCacheManager, times(1)).setAllowedRunningEntities(anyString(), anyInt()); + verify(adTaskCacheManager, times(1)).scaleUpDetectorTaskSlots(anyString(), anyInt()); + verify(listener, times(1)).onResponse(any()); + + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(adTaskCacheManager, times(1)).setAllowedRunningEntities(anyString(), anyInt()); + verify(adTaskCacheManager, times(1)).scaleUpDetectorTaskSlots(anyString(), anyInt()); + verify(listener, times(2)).onResponse(any()); + } + + public void testCancelSingleEntityDetector() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_SINGLE_ENTITY); + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, CANCEL); + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(listener, times(1)).onFailure(any()); + } + + public void testCancelHCDetector() throws IOException { + when(adTaskCacheManager.hasEntity(anyString())).thenReturn(true).thenReturn(false).thenReturn(true); + when(adTaskManager.convertEntityToString(any())).thenReturn(randomAlphaOfLength(5)); + + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_ENTITY); + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, CANCEL); + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(adTaskCacheManager, times(1)).clearPendingEntities(anyString()); + verify(adTaskCacheManager, times(1)).removeRunningEntity(anyString(), anyString()); + verify(listener, times(1)).onResponse(any()); + + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(adTaskCacheManager, times(2)).clearPendingEntities(anyString()); + verify(adTaskCacheManager, times(2)).removeRunningEntity(anyString(), anyString()); + verify(adTaskManager, times(1)).setHCDetectorTaskDone(any(), any(), any()); + verify(listener, times(2)).onResponse(any()); + + request = new ForwardADTaskRequest(TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_DETECTOR), CANCEL); + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(adTaskCacheManager, times(3)).clearPendingEntities(anyString()); + verify(adTaskCacheManager, times(3)).removeRunningEntity(anyString(), anyString()); + verify(adTaskManager, times(2)).setHCDetectorTaskDone(any(), any(), any()); + verify(listener, times(3)).onResponse(any()); + } + + public void testCleanStaleRunningEntities() throws IOException { + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_HC_DETECTOR); + ImmutableList staleEntities = ImmutableList.of(randomAlphaOfLength(5)); + ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, CLEAN_STALE_RUNNING_ENTITIES, staleEntities); + forwardADTaskTransportAction.doExecute(task, request, listener); + verify(adTaskManager, times(staleEntities.size())).removeStaleRunningEntity(any(), any(), any(), any()); + verify(listener, times(1)).onResponse(any()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java index 4292dc6ac..60144c63c 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java @@ -23,8 +23,8 @@ import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.DetectorProfile; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; -import 
org.opensearch.rest.RestStatus; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.rest.RestStatus; import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java-e b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java-e new file mode 100644 index 000000000..60144c63c --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java-e @@ -0,0 +1,79 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mockito; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.model.DetectorProfile; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.rest.RestStatus; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +@RunWith(PowerMockRunner.class) +@PrepareForTest(GetAnomalyDetectorResponse.class) +public class GetAnomalyDetectorActionTests { + @Before + public void setUp() throws Exception { + + } + + @Test + public void testGetRequest() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + GetAnomalyDetectorRequest request = new GetAnomalyDetectorRequest("1234", 4321, false, false, "nonempty", "", false, null); + request.writeTo(out); + StreamInput input = out.bytes().streamInput(); + GetAnomalyDetectorRequest newRequest = new GetAnomalyDetectorRequest(input); + Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID()); + + } + + @Test + public void testGetResponse() throws Exception { + BytesStreamOutput out = new BytesStreamOutput(); + AnomalyDetector detector = Mockito.mock(AnomalyDetector.class); + AnomalyDetectorJob detectorJob = Mockito.mock(AnomalyDetectorJob.class); + Mockito.doNothing().when(detector).writeTo(out); + GetAnomalyDetectorResponse response = new GetAnomalyDetectorResponse( + 1234, + "4567", + 9876, + 2345, + detector, + detectorJob, + false, + Mockito.mock(ADTask.class), + Mockito.mock(ADTask.class), + false, + RestStatus.OK, + Mockito.mock(DetectorProfile.class), + null, + false + ); + response.writeTo(out); + StreamInput input = out.bytes().streamInput(); + PowerMockito.whenNew(AnomalyDetector.class).withAnyArguments().thenReturn(detector); + GetAnomalyDetectorResponse newResponse = new GetAnomalyDetectorResponse(input); + Assert.assertNotNull(newResponse); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorResponseTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorResponseTests.java index fa9ff1b8c..ace2c3c8c 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorResponseTests.java +++ 
b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorResponseTests.java @@ -16,15 +16,15 @@ import java.time.temporal.ChronoUnit; import java.util.Collection; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.rest.RestStatus; import org.opensearch.plugins.Plugin; -import org.opensearch.rest.RestStatus; import org.opensearch.test.InternalSettingsPlugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -33,7 +33,7 @@ public class GetAnomalyDetectorResponseTests extends OpenSearchSingleNodeTestCas @Override protected Collection<Class<? extends Plugin>> getPlugins() { - return pluginList(InternalSettingsPlugin.class, AnomalyDetectorPlugin.class); + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); } @Override diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorResponseTests.java-e b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorResponseTests.java-e new file mode 100644 index 000000000..240a4683e --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorResponseTests.java-e @@ -0,0 +1,98 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collection; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.plugins.Plugin; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.test.InternalSettingsPlugin; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class GetAnomalyDetectorResponseTests extends OpenSearchSingleNodeTestCase { + + @Override + protected Collection> getPlugins() { + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); + } + + @Override + protected NamedWriteableRegistry writableRegistry() { + return getInstanceFromNode(NamedWriteableRegistry.class); + } + + public void testConstructor() throws IOException { + GetAnomalyDetectorResponse response = createGetAnomalyDetectorResponse(false, false); + assertNull(response.getAdJob()); + assertNull(response.getRealtimeAdTask()); + assertNull(response.getHistoricalAdTask()); + response = createGetAnomalyDetectorResponse(true, true); + assertNotNull(response.getAdJob()); + assertNotNull(response.getRealtimeAdTask()); + assertNotNull(response.getHistoricalAdTask()); + } + + public void testSerializationWithoutJobAndTask() throws IOException { + GetAnomalyDetectorResponse response = createGetAnomalyDetectorResponse(false, false); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + GetAnomalyDetectorResponse parsedResponse = new GetAnomalyDetectorResponse(input); + assertNull(parsedResponse.getAdJob()); + assertNull(parsedResponse.getRealtimeAdTask()); + assertNull(parsedResponse.getHistoricalAdTask()); + assertEquals(response.getDetector(), parsedResponse.getDetector()); + } + + public void testSerializationWithJobAndTask() throws IOException { + GetAnomalyDetectorResponse response = createGetAnomalyDetectorResponse(true, true); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + GetAnomalyDetectorResponse parsedResponse = new GetAnomalyDetectorResponse(input); + assertNotNull(parsedResponse.getAdJob()); + assertNotNull(parsedResponse.getRealtimeAdTask()); + assertNotNull(parsedResponse.getHistoricalAdTask()); + assertEquals(response.getDetector(), parsedResponse.getDetector()); + } + + private GetAnomalyDetectorResponse createGetAnomalyDetectorResponse(boolean returnJob, boolean returnTask) throws IOException { + GetAnomalyDetectorResponse response = new GetAnomalyDetectorResponse( + randomLong(), + randomAlphaOfLength(5), + randomLong(), + randomLong(), + TestHelpers.randomAnomalyDetector(ImmutableList.of(), ImmutableMap.of(), Instant.now().truncatedTo(ChronoUnit.SECONDS)), + TestHelpers.randomAnomalyDetectorJob(), + returnJob, + TestHelpers.randomAdTask(), + TestHelpers.randomAdTask(), + returnTask, + RestStatus.OK, + null, + null, + false + ); + return response; + 
} +} diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java index 620ed436f..4a3f2a89c 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java @@ -50,9 +50,9 @@ import org.opensearch.ad.util.Throttler; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.index.get.GetResult; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.constant.CommonMessages; diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java-e b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java-e new file mode 100644 index 000000000..2ad7134f7 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java-e @@ -0,0 +1,244 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.time.Clock; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.function.Consumer; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.get.MultiGetItemResponse; +import org.opensearch.action.get.MultiGetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.ad.util.Throttler; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.get.GetResult; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import 
org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportService; + +public class GetAnomalyDetectorTests extends AbstractTimeSeriesTest { + private GetAnomalyDetectorTransportAction action; + private TransportService transportService; + private DiscoveryNodeFilterer nodeFilter; + private ActionFilters actionFilters; + private Client client; + private SecurityClientUtil clientUtil; + private GetAnomalyDetectorRequest request; + private String detectorId = "yecrdnUBqurvo9uKU_d8"; + private String entityValue = "app_0"; + private String categoryField = "categoryField"; + private String typeStr; + private String rawPath; + private PlainActionFuture future; + private ADTaskManager adTaskManager; + private Entity entity; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(EntityProfileTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @Override + public void setUp() throws Exception { + super.setUp(); + ClusterService clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + transportService = new TransportService( + Settings.EMPTY, + mock(Transport.class), + null, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + x -> null, + null, + Collections.emptySet() + ); + + nodeFilter = mock(DiscoveryNodeFilterer.class); + + actionFilters = mock(ActionFilters.class); + + client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + + Clock clock = mock(Clock.class); + Throttler throttler = new Throttler(clock); + + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + + adTaskManager = mock(ADTaskManager.class); + + action = new GetAnomalyDetectorTransportAction( + transportService, + nodeFilter, + actionFilters, + clusterService, + client, + clientUtil, + Settings.EMPTY, + xContentRegistry(), + adTaskManager + ); + + entity = Entity.createSingleAttributeEntity(categoryField, entityValue); + } + + public void testInvalidRequest() throws IOException { + typeStr = "entity_info2,init_progress2"; + + rawPath = "_opendistro/_anomaly_detection/detectors/T4c3dXUBj-2IZN7itix_/_profile"; + + request = new GetAnomalyDetectorRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entity); + + future = new PlainActionFuture<>(); + action.doExecute(null, request, future); + assertException(future, OpenSearchStatusException.class, ADCommonMessages.EMPTY_PROFILES_COLLECT); + } + + @SuppressWarnings("unchecked") + public void testValidRequest() throws IOException { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + GetRequest request = (GetRequest) args[0]; + ActionListener listener = (ActionListener) args[1]; + + String indexName = request.index(); + if (indexName.equals(CommonName.CONFIG_INDEX)) { + listener.onResponse(null); + } + return null; + }).when(client).get(any(), any()); + + typeStr = "entity_info,init_progress"; + + rawPath = "_opendistro/_anomaly_detection/detectors/T4c3dXUBj-2IZN7itix_/_profile"; + + request = new GetAnomalyDetectorRequest(detectorId, 0L, false, false, typeStr, rawPath, false, 
entity); + + future = new PlainActionFuture<>(); + action.doExecute(null, request, future); + assertException(future, OpenSearchStatusException.class, CommonMessages.FAIL_TO_FIND_CONFIG_MSG); + } + + public void testGetTransportActionWithReturnTask() { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + Consumer> consumer = (Consumer>) args[4]; + + consumer.accept(createADTaskList()); + return null; + }) + .when(adTaskManager) + .getAndExecuteOnLatestADTasks( + anyString(), + eq(null), + eq(null), + anyList(), + any(), + eq(transportService), + eq(true), + anyInt(), + any() + ); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + + listener.onResponse(createMultiGetResponse()); + return null; + }).when(client).multiGet(any(), any()); + + rawPath = "_opendistro/_anomaly_detection/detectors/T4c3dXUBj-2IZN7itix_"; + + request = new GetAnomalyDetectorRequest(detectorId, 0L, false, true, typeStr, rawPath, false, entity); + future = new PlainActionFuture<>(); + action.getExecute(request, future); + + verify(client).multiGet(any(), any()); + } + + private MultiGetResponse createMultiGetResponse() { + MultiGetItemResponse[] items = new MultiGetItemResponse[2]; + ByteBuffer[] buffers = new ByteBuffer[0]; + items[0] = new MultiGetItemResponse( + new GetResponse( + new GetResult(CommonName.JOB_INDEX, "test_1", 1, 1, 0, true, BytesReference.fromByteBuffers(buffers), null, null) + ), + null + ); + items[1] = new MultiGetItemResponse( + new GetResponse( + new GetResult(CommonName.JOB_INDEX, "test_2", 1, 1, 0, true, BytesReference.fromByteBuffers(buffers), null, null) + ), + null + ); + return new MultiGetResponse(items); + } + + private List createADTaskList() { + ADTask adTask1 = new ADTask.Builder().taskId("test1").taskType(ADTaskType.REALTIME_SINGLE_ENTITY.name()).build(); + ADTask adTask2 = new ADTask.Builder().taskId("test2").taskType(ADTaskType.REALTIME_SINGLE_ENTITY.name()).build(); + ADTask adTask3 = new ADTask.Builder().taskId("test3").taskType(ADTaskType.REALTIME_HC_DETECTOR.name()).build(); + ADTask adTask4 = new ADTask.Builder().taskId("test4").taskType(ADTaskType.HISTORICAL_HC_DETECTOR.name()).build(); + ADTask adTask5 = new ADTask.Builder().taskId("test5").taskType(ADTaskType.HISTORICAL_SINGLE_ENTITY.name()).build(); + + return Arrays.asList(adTask1, adTask2, adTask3, adTask4, adTask5); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java index 316308e55..34f1485c2 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java @@ -38,14 +38,14 @@ import org.opensearch.ad.util.*; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.rest.RestStatus; import 
org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.rest.RestStatus; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.threadpool.TestThreadPool; diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java-e new file mode 100644 index 000000000..52bf2ce62 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java-e @@ -0,0 +1,264 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import org.junit.*; +import org.mockito.Mockito; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.model.EntityProfile; +import org.opensearch.ad.model.InitProgressProfile; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.util.*; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableMap; + +public class GetAnomalyDetectorTransportActionTests extends OpenSearchSingleNodeTestCase { + private static ThreadPool threadPool; + private GetAnomalyDetectorTransportAction action; + private Task task; + private ActionListener response; + private ADTaskManager adTaskManager; + private Entity entity; + private String categoryField; + private String categoryValue; + + @BeforeClass + public static void beforeCLass() { + threadPool = new TestThreadPool("GetAnomalyDetectorTransportActionTests"); + } + + @AfterClass + public static void afterClass() { + 
ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + } + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + ClusterService clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + adTaskManager = mock(ADTaskManager.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + SecurityClientUtil clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + action = new GetAnomalyDetectorTransportAction( + Mockito.mock(TransportService.class), + Mockito.mock(DiscoveryNodeFilterer.class), + Mockito.mock(ActionFilters.class), + clusterService, + client(), + clientUtil, + Settings.EMPTY, + xContentRegistry(), + adTaskManager + ); + task = Mockito.mock(Task.class); + response = new ActionListener() { + @Override + public void onResponse(GetAnomalyDetectorResponse getResponse) { + // When no detectors exist, get response is not generated + assertTrue(true); + } + + @Override + public void onFailure(Exception e) {} + }; + categoryField = "catField"; + categoryValue = "app-0"; + entity = Entity.createSingleAttributeEntity(categoryField, categoryValue); + } + + @Override + protected NamedWriteableRegistry writableRegistry() { + return getInstanceFromNode(NamedWriteableRegistry.class); + } + + @Test + public void testGetTransportAction() throws IOException { + GetAnomalyDetectorRequest getAnomalyDetectorRequest = new GetAnomalyDetectorRequest( + "1234", + 4321, + false, + false, + "nonempty", + "", + false, + null + ); + action.doExecute(task, getAnomalyDetectorRequest, response); + } + + @Test + public void testGetTransportActionWithReturnJob() throws IOException { + GetAnomalyDetectorRequest getAnomalyDetectorRequest = new GetAnomalyDetectorRequest( + "1234", + 4321, + true, + false, + "", + "abcd", + false, + null + ); + action.doExecute(task, getAnomalyDetectorRequest, response); + } + + @Test + public void testGetAction() { + Assert.assertNotNull(GetAnomalyDetectorAction.INSTANCE.name()); + Assert.assertEquals(GetAnomalyDetectorAction.INSTANCE.name(), GetAnomalyDetectorAction.NAME); + } + + @Test + public void testGetAnomalyDetectorRequest() throws IOException { + GetAnomalyDetectorRequest request = new GetAnomalyDetectorRequest("1234", 4321, true, false, "", "abcd", false, entity); + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + StreamInput input = out.bytes().streamInput(); + GetAnomalyDetectorRequest newRequest = new GetAnomalyDetectorRequest(input); + Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID()); + Assert.assertEquals(request.getRawPath(), newRequest.getRawPath()); + Assert.assertNull(newRequest.validate()); + } + + @Test + public void testGetAnomalyDetectorRequestNoEntityValue() throws IOException { + GetAnomalyDetectorRequest request = new GetAnomalyDetectorRequest("1234", 4321, true, false, "", "abcd", false, null); + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + StreamInput input = out.bytes().streamInput(); + GetAnomalyDetectorRequest newRequest = new GetAnomalyDetectorRequest(input); + Assert.assertNull(newRequest.getEntity()); + } + + @SuppressWarnings("unchecked") + @Test + public void testGetAnomalyDetectorResponse() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + 
AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + AnomalyDetectorJob adJob = TestHelpers.randomAnomalyDetectorJob(); + GetAnomalyDetectorResponse response = new GetAnomalyDetectorResponse( + 4321, + "1234", + 5678, + 9867, + detector, + adJob, + false, + mock(ADTask.class), + mock(ADTask.class), + false, + RestStatus.OK, + null, + null, + false + ); + response.writeTo(out); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), writableRegistry()); + GetAnomalyDetectorResponse newResponse = new GetAnomalyDetectorResponse(input); + XContentBuilder builder = TestHelpers.builder(); + Assert.assertNotNull(newResponse.toXContent(builder, ToXContent.EMPTY_PARAMS)); + + Map map = TestHelpers.XContentBuilderToMap(builder); + Assert.assertTrue(map.get(RestHandlerUtils.ANOMALY_DETECTOR) instanceof Map); + Map map1 = (Map) map.get(RestHandlerUtils.ANOMALY_DETECTOR); + Assert.assertEquals(map1.get("name"), detector.getName()); + } + + @SuppressWarnings("unchecked") + @Test + public void testGetAnomalyDetectorProfileResponse() throws IOException { + BytesStreamOutput out = new BytesStreamOutput(); + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + AnomalyDetectorJob adJob = TestHelpers.randomAnomalyDetectorJob(); + InitProgressProfile initProgress = new InitProgressProfile("99%", 2L, 2); + EntityProfile entityProfile = new EntityProfile.Builder().initProgress(initProgress).build(); + GetAnomalyDetectorResponse response = new GetAnomalyDetectorResponse( + 4321, + "1234", + 5678, + 9867, + detector, + adJob, + false, + mock(ADTask.class), + mock(ADTask.class), + false, + RestStatus.OK, + null, + entityProfile, + true + ); + response.writeTo(out); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), writableRegistry()); + GetAnomalyDetectorResponse newResponse = new GetAnomalyDetectorResponse(input); + XContentBuilder builder = TestHelpers.builder(); + Assert.assertNotNull(newResponse.toXContent(builder, ToXContent.EMPTY_PARAMS)); + + // {init_progress={percentage=99%, estimated_minutes_left=2, needed_shingles=2}} + Map map = TestHelpers.XContentBuilderToMap(builder); + Map parsedInitProgress = (Map) (map.get(ADCommonName.INIT_PROGRESS)); + Assert.assertEquals(initProgress.getPercentage(), parsedInitProgress.get(InitProgressProfile.PERCENTAGE).toString()); + assertTrue(initProgress.toString().contains("[percentage=99%,estimated_minutes_left=2,needed_shingles=2]")); + Assert + .assertEquals( + String.valueOf(initProgress.getEstimatedMinutesLeft()), + parsedInitProgress.get(InitProgressProfile.ESTIMATED_MINUTES_LEFT).toString() + ); + Assert + .assertEquals( + String.valueOf(initProgress.getNeededDataPoints()), + parsedInitProgress.get(InitProgressProfile.NEEDED_SHINGLES).toString() + ); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java index 61fcbb50c..f29030912 100644 --- a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java @@ -20,13 +20,13 @@ import org.opensearch.action.support.WriteRequest; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.common.io.stream.BytesStreamOutput; -import 
org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.rest.RestRequest; -import org.opensearch.rest.RestStatus; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.util.RestHandlerUtils; diff --git a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java-e b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java-e new file mode 100644 index 000000000..ce262ef43 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java-e @@ -0,0 +1,85 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.time.Instant; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.rest.RestRequest; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.util.RestHandlerUtils; + +import com.google.common.collect.ImmutableMap; + +public class IndexAnomalyDetectorActionTests extends OpenSearchSingleNodeTestCase { + @Before + public void setUp() throws Exception { + super.setUp(); + } + + @Override + protected NamedWriteableRegistry writableRegistry() { + return getInstanceFromNode(NamedWriteableRegistry.class); + } + + @Test + public void testIndexRequest() throws Exception { + BytesStreamOutput out = new BytesStreamOutput(); + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + IndexAnomalyDetectorRequest request = new IndexAnomalyDetectorRequest( + "1234", + 4321, + 5678, + WriteRequest.RefreshPolicy.NONE, + detector, + RestRequest.Method.PUT, + TimeValue.timeValueSeconds(60), + 1000, + 10, + 5 + ); + request.writeTo(out); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), writableRegistry()); + IndexAnomalyDetectorRequest newRequest = new IndexAnomalyDetectorRequest(input); + Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID()); + Assert.assertNull(newRequest.validate()); + } + + @Test + public void testIndexResponse() throws Exception { + BytesStreamOutput out = new BytesStreamOutput(); + AnomalyDetector detector = 
TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + IndexAnomalyDetectorResponse response = new IndexAnomalyDetectorResponse("1234", 56, 78, 90, detector, RestStatus.OK); + response.writeTo(out); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), writableRegistry()); + IndexAnomalyDetectorResponse newResponse = new IndexAnomalyDetectorResponse(input); + Assert.assertEquals(response.getId(), newResponse.getId()); + XContentBuilder builder = TestHelpers.builder(); + Assert.assertNotNull(newResponse.toXContent(builder, ToXContent.EMPTY_PARAMS)); + + Map map = TestHelpers.XContentBuilderToMap(builder); + Assert.assertEquals(map.get(RestHandlerUtils._ID), "1234"); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java-e new file mode 100644 index 000000000..0a3859bc2 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java-e @@ -0,0 +1,256 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Locale; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchIntegTestCase; +import 
org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableMap; + +public class IndexAnomalyDetectorTransportActionTests extends OpenSearchIntegTestCase { + private IndexAnomalyDetectorTransportAction action; + private Task task; + private IndexAnomalyDetectorRequest request; + private ActionListener response; + private ClusterService clusterService; + private ClusterSettings clusterSettings; + private ADTaskManager adTaskManager; + private Client client = mock(Client.class); + private SecurityClientUtil clientUtil; + private SearchFeatureDao searchFeatureDao; + + @SuppressWarnings("unchecked") + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + clusterService = mock(ClusterService.class); + clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + ClusterName clusterName = new ClusterName("test"); + Settings indexSettings = Settings + .builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .build(); + final Settings.Builder existingSettings = Settings.builder().put(indexSettings).put(IndexMetadata.SETTING_INDEX_UUID, "test2UUID"); + IndexMetadata indexMetaData = IndexMetadata.builder(CommonName.CONFIG_INDEX).settings(existingSettings).build(); + final Map indices = new HashMap<>(); + indices.put(CommonName.CONFIG_INDEX, indexMetaData); + ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().indices(indices).build()).build(); + when(clusterService.state()).thenReturn(clusterState); + + adTaskManager = mock(ADTaskManager.class); + searchFeatureDao = mock(SearchFeatureDao.class); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + action = new IndexAnomalyDetectorTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + client(), + clientUtil, + clusterService, + indexSettings(), + mock(ADIndexManagement.class), + xContentRegistry(), + adTaskManager, + searchFeatureDao + ); + task = mock(Task.class); + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + GetResponse getDetectorResponse = TestHelpers.createGetResponse(detector, detector.getId(), CommonName.CONFIG_INDEX); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. 
Its content is %s", args.length, Arrays.toString(args)), + args.length == 2 + ); + + assertTrue(args[0] instanceof GetRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[1]; + listener.onResponse(getDetectorResponse); + return null; + }).when(client).get(any(GetRequest.class), any()); + + SearchHits hits = new SearchHits(new SearchHit[] {}, null, Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections(hits, null, null, false, false, null, 1); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 30, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + args.length == 2 + ); + + assertTrue(args[0] instanceof SearchRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), any()); + + request = new IndexAnomalyDetectorRequest( + "1234", + 4567, + 7890, + WriteRequest.RefreshPolicy.IMMEDIATE, + detector, + RestRequest.Method.PUT, + TimeValue.timeValueSeconds(60), + 1000, + 10, + 5 + ); + response = new ActionListener() { + @Override + public void onResponse(IndexAnomalyDetectorResponse indexResponse) { + // onResponse will not be called as we do not have the AD index + Assert.assertTrue(false); + } + + @Override + public void onFailure(Exception e) { + Assert.assertTrue(true); + } + }; + } + + @Test + public void testIndexTransportAction() { + action.doExecute(task, request, response); + } + + @Test + public void testIndexTransportActionWithUserAndFilterOn() { + Settings settings = Settings.builder().put(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.getKey(), true).build(); + ThreadContext threadContext = new ThreadContext(settings); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alice|odfe,aes|engineering,operations"); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + org.opensearch.threadpool.ThreadPool mockThreadPool = mock(ThreadPool.class); + when(client.threadPool()).thenReturn(mockThreadPool); + when(mockThreadPool.getThreadContext()).thenReturn(threadContext); + + IndexAnomalyDetectorTransportAction transportAction = new IndexAnomalyDetectorTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + client, + clientUtil, + clusterService, + settings, + mock(ADIndexManagement.class), + xContentRegistry(), + adTaskManager, + searchFeatureDao + + ); + transportAction.doExecute(task, request, response); + } + + @Test + public void testIndexTransportActionWithUserAndFilterOff() { + Settings settings = Settings.builder().build(); + ThreadContext threadContext = new ThreadContext(settings); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alice|odfe,aes|engineering,operations"); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + org.opensearch.threadpool.ThreadPool mockThreadPool = mock(ThreadPool.class); + when(client.threadPool()).thenReturn(mockThreadPool); + when(mockThreadPool.getThreadContext()).thenReturn(threadContext); + + IndexAnomalyDetectorTransportAction transportAction = new IndexAnomalyDetectorTransportAction( + 
mock(TransportService.class), + mock(ActionFilters.class), + client, + clientUtil, + clusterService, + settings, + mock(ADIndexManagement.class), + xContentRegistry(), + adTaskManager, + searchFeatureDao + ); + transportAction.doExecute(task, request, response); + } + + @Test + public void testIndexDetectorAction() { + Assert.assertNotNull(IndexAnomalyDetectorAction.INSTANCE.name()); + Assert.assertEquals(IndexAnomalyDetectorAction.INSTANCE.name(), IndexAnomalyDetectorAction.NAME); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java index 6c2a88240..9ec5aa9d5 100644 --- a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java @@ -98,12 +98,12 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodeRole; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.index.IndexNotFoundException; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchHits; diff --git a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java-e b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java-e new file mode 100644 index 000000000..feb9d1985 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java-e @@ -0,0 +1,1419 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.hamcrest.Matchers.containsString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Matchers.argThat; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.PAGE_SIZE; + +import java.io.IOException; +import java.time.Clock; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import org.apache.commons.lang3.tuple.Pair; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatcher; +import org.mockito.stubbing.Answer; +import org.opensearch.OpenSearchTimeoutException; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchPhaseExecutionException; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponse.Clusters; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.feature.CompositeRetriever; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.ratelimit.CheckpointReadWorker; +import org.opensearch.ad.ratelimit.ColdEntityWorker; +import org.opensearch.ad.ratelimit.EntityColdStartWorker; +import org.opensearch.ad.ratelimit.EntityFeatureRequest; +import org.opensearch.ad.ratelimit.ResultWriteWorker; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.stats.ADStat; +import org.opensearch.ad.stats.ADStats; +import 
org.opensearch.ad.stats.suppliers.CounterSupplier; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.ad.util.Throttler; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodeRole; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; +import org.opensearch.search.aggregations.metrics.InternalMin; +import org.opensearch.test.ClusterServiceUtils; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportInterceptor; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.TransportResponse; +import org.opensearch.transport.TransportResponseHandler; +import org.opensearch.transport.TransportService; + +import test.org.opensearch.ad.util.MLUtil; +import test.org.opensearch.ad.util.RandomModelStateConfig; + +import com.google.common.collect.ImmutableList; + +public class MultiEntityResultTests extends AbstractTimeSeriesTest { + private AnomalyResultTransportAction action; + private AnomalyResultRequest request; + private TransportInterceptor entityResultInterceptor; + private Clock clock; + private AnomalyDetector detector; + private NodeStateManager stateManager; + private static Settings settings; + private TransportService transportService; + private Client client; + private SecurityClientUtil clientUtil; + private FeatureManager featureQuery; + private ModelManager normalModelManager; + private HashRing hashRing; + private ClusterService clusterService; + private IndexNameExpressionResolver indexNameResolver; + private ADCircuitBreakerService adCircuitBreakerService; + private ADStats adStats; + private ThreadPool mockThreadPool; + private String detectorId; + private Instant now; + private CacheProvider provider; + private ADIndexManagement indexUtil; + private ResultWriteWorker resultWriteQueue; + private CheckpointReadWorker checkpointReadQueue; 
+ private EntityColdStartWorker entityColdStartQueue; + private ColdEntityWorker coldEntityQueue; + private String app0 = "app_0"; + private String server1 = "server_1"; + private String server2 = "server_2"; + private String server3 = "server_3"; + private String serviceField = "service"; + private String hostField = "host"; + private Map attrs1, attrs2, attrs3; + private EntityCache entityCache; + private ADTaskManager adTaskManager; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyResultTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + @SuppressWarnings({ "serial", "unchecked" }) + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + now = Instant.now(); + clock = mock(Clock.class); + when(clock.instant()).thenReturn(now); + + detectorId = "123"; + String categoryField = "a"; + detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Collections.singletonList(categoryField)); + + stateManager = mock(NodeStateManager.class); + // make sure parameters are not null, otherwise this mock won't get invoked + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(stateManager).getAnomalyDetector(anyString(), any(ActionListener.class)); + + settings = Settings.builder().put(AnomalyDetectorSettings.COOLDOWN_MINUTES.getKey(), TimeValue.timeValueMinutes(5)).build(); + + // make sure end time is larger enough than Clock.systemUTC().millis() to get PageIterator.hasNext() to pass + request = new AnomalyResultRequest(detectorId, 100, Clock.systemUTC().millis() + 100_000); + + transportService = mock(TransportService.class); + + client = mock(Client.class); + ThreadContext threadContext = new ThreadContext(settings); + mockThreadPool = mock(ThreadPool.class); + setUpADThreadPool(mockThreadPool); + when(client.threadPool()).thenReturn(mockThreadPool); + when(mockThreadPool.getThreadContext()).thenReturn(threadContext); + clientUtil = new SecurityClientUtil(stateManager, settings); + + featureQuery = mock(FeatureManager.class); + + normalModelManager = mock(ModelManager.class); + + hashRing = mock(HashRing.class); + + Set> anomalyResultSetting = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + anomalyResultSetting.add(MAX_ENTITIES_PER_QUERY); + anomalyResultSetting.add(PAGE_SIZE); + anomalyResultSetting.add(MAX_RETRY_FOR_UNRESPONSIVE_NODE); + anomalyResultSetting.add(BACKOFF_MINUTES); + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, anomalyResultSetting); + + DiscoveryNode discoveryNode = new DiscoveryNode( + "node1", + OpenSearchTestCase.buildNewFakeTransportAddress(), + Collections.emptyMap(), + DiscoveryNodeRole.BUILT_IN_ROLES, + Version.CURRENT + ); + + clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); + + indexNameResolver = new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY)); + + adCircuitBreakerService = mock(ADCircuitBreakerService.class); + when(adCircuitBreakerService.isOpen()).thenReturn(false); + + Map> statsMap = new HashMap>() { + { + put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + 
put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + } + }; + adStats = new ADStats(statsMap); + + adTaskManager = mock(ADTaskManager.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(true); + return null; + }) + .when(adTaskManager) + .initRealtimeTaskCacheAndCleanupStaleCache( + anyString(), + any(AnomalyDetector.class), + any(TransportService.class), + any(ActionListener.class) + ); + + action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + mockThreadPool, + xContentRegistry(), + adTaskManager + ); + + provider = mock(CacheProvider.class); + entityCache = mock(EntityCache.class); + when(provider.get()).thenReturn(entityCache); + when(entityCache.get(any(), any())) + .thenReturn(MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build())); + when(entityCache.selectUpdateCandidate(any(), any(), any())).thenReturn(Pair.of(new ArrayList(), new ArrayList())); + + indexUtil = mock(ADIndexManagement.class); + resultWriteQueue = mock(ResultWriteWorker.class); + checkpointReadQueue = mock(CheckpointReadWorker.class); + entityColdStartQueue = mock(EntityColdStartWorker.class); + + coldEntityQueue = mock(ColdEntityWorker.class); + + attrs1 = new HashMap<>(); + attrs1.put(serviceField, app0); + attrs1.put(hostField, server1); + + attrs2 = new HashMap<>(); + attrs2.put(serviceField, app0); + attrs2.put(hostField, server2); + + attrs3 = new HashMap<>(); + attrs3.put(serviceField, app0); + attrs3.put(hostField, server3); + } + + @Override + @After + public final void tearDown() throws Exception { + tearDownTestNodes(); + super.tearDown(); + } + + public void testColdStartEndRunException() { + when(stateManager.fetchExceptionAndClear(anyString())) + .thenReturn( + Optional + .of( + new EndRunException( + detectorId, + CommonMessages.INVALID_SEARCH_QUERY_MSG, + new NoSuchElementException("No value present"), + false + ) + ) + ); + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + assertException(listener, EndRunException.class, CommonMessages.INVALID_SEARCH_QUERY_MSG); + } + + // a handler that forwards response or exception received from network + private TransportResponseHandler entityResultHandler(TransportResponseHandler handler) { + return new TransportResponseHandler() { + @Override + public T read(StreamInput in) throws IOException { + return handler.read(in); + } + + @Override + @SuppressWarnings("unchecked") + public void handleResponse(T response) { + handler.handleResponse(response); + } + + @Override + public void handleException(TransportException exp) { + handler.handleException(exp); + } + + @Override + public String executor() { + return handler.executor(); + } + }; + } + + private TransportResponseHandler unackEntityResultHandler(TransportResponseHandler handler) { + return new TransportResponseHandler() { + @Override + public T read(StreamInput in) throws IOException { + return handler.read(in); + } + + @Override + @SuppressWarnings("unchecked") + public void handleResponse(T response) { + handler.handleResponse((T) new AcknowledgedResponse(false)); + } + + @Override + public void 
handleException(TransportException exp) { + handler.handleException(exp); + } + + @Override + public String executor() { + return handler.executor(); + } + }; + } + + private void setUpEntityResult(int nodeIndex, NodeStateManager nodeStateManager) { + // register entity result action + new EntityResultTransportAction( + new ActionFilters(Collections.emptySet()), + // since we send requests to testNodes[1] + testNodes[nodeIndex].transportService, + normalModelManager, + adCircuitBreakerService, + provider, + nodeStateManager, + indexUtil, + resultWriteQueue, + checkpointReadQueue, + coldEntityQueue, + threadPool, + entityColdStartQueue, + adStats + ); + + when(normalModelManager.getAnomalyResultForEntity(any(), any(), any(), any(), anyInt())) + .thenReturn(new ThresholdingResult(0, 1, 1)); + } + + private void setUpEntityResult(int nodeIndex) { + setUpEntityResult(nodeIndex, stateManager); + } + + @SuppressWarnings("unchecked") + public void setUpNormlaStateManager() throws IOException { + AnomalyDetector detector = TestHelpers.AnomalyDetectorBuilder + .newInstance() + .setDetectionInterval(new IntervalTimeConfiguration(1, ChronoUnit.MINUTES)) + .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5))) + .build(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, CommonName.CONFIG_INDEX)); + return null; + }).when(client).get(any(GetRequest.class), any(ActionListener.class)); + + stateManager = new NodeStateManager( + client, + xContentRegistry(), + settings, + new ClientUtil(settings, client, new Throttler(mock(Clock.class)), threadPool), + clock, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + clusterService + ); + + clientUtil = new SecurityClientUtil(stateManager, settings); + + action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + transportService, + settings, + client, + clientUtil, + stateManager, + featureQuery, + normalModelManager, + hashRing, + clusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + mockThreadPool, + xContentRegistry(), + adTaskManager + ); + } + + /** + * Test query error causes EndRunException but not end now + * @throws InterruptedException when the await are interrupted + * @throws IOException when failing to create anomaly detector + */ + public void testQueryErrorEndRunNotNow() throws InterruptedException, IOException { + setUpNormlaStateManager(); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + String allShardsFailedMsg = "all shards failed"; + // make PageIterator.next return failure + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener + .onFailure( + new SearchPhaseExecutionException( + "search", + allShardsFailedMsg, + new ShardSearchFailure[] { new ShardSearchFailure(new IllegalArgumentException("blah")) } + ) + ); + inProgressLatch.countDown(); + return null; + }).when(client).search(any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.001); + + assertTrue(inProgressLatch.await(10000L, TimeUnit.MILLISECONDS)); + + PlainActionFuture listener2 = new PlainActionFuture<>(); + action.doExecute(null, request, listener2); + Exception e = expectThrows(EndRunException.class, () -> listener2.actionGet(10000L)); + // wrapped INVALID_SEARCH_QUERY_MSG 
around SearchPhaseExecutionException by convertedQueryFailureException + assertThat("actual message: " + e.getMessage(), e.getMessage(), containsString(CommonMessages.INVALID_SEARCH_QUERY_MSG)); + assertThat("actual message: " + e.getMessage(), e.getMessage(), containsString(allShardsFailedMsg)); + // not end now + assertTrue(!((EndRunException) e).isEndNow()); + } + + public void testIndexNotFound() throws InterruptedException, IOException { + setUpNormlaStateManager(); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + // make PageIterator.next return failure + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new IndexNotFoundException("", "")); + inProgressLatch.countDown(); + return null; + }).when(client).search(any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.001); + + assertTrue(inProgressLatch.await(10000L, TimeUnit.MILLISECONDS)); + + PlainActionFuture listener2 = new PlainActionFuture<>(); + action.doExecute(null, request, listener2); + Exception e = expectThrows(EndRunException.class, () -> listener2.actionGet(10000L)); + assertThat( + "actual message: " + e.getMessage(), + e.getMessage(), + containsString(AnomalyResultTransportAction.TROUBLE_QUERYING_ERR_MSG) + ); + assertTrue(!((EndRunException) e).isEndNow()); + } + + public void testEmptyFeatures() throws InterruptedException { + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(createEmptyResponse()); + inProgressLatch.countDown(); + return null; + }).when(client).search(any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(inProgressLatch.await(10000L, TimeUnit.MILLISECONDS)); + + PlainActionFuture listener2 = new PlainActionFuture<>(); + action.doExecute(null, request, listener2); + + AnomalyResultResponse response2 = listener2.actionGet(10000L); + assertEquals(Double.NaN, response2.getAnomalyGrade(), 0.01); + } + + /** + * + * @return an empty response + */ + private SearchResponse createEmptyResponse() { + CompositeAggregation emptyComposite = mock(CompositeAggregation.class); + when(emptyComposite.getName()).thenReturn(CompositeRetriever.AGG_NAME_COMP); + when(emptyComposite.afterKey()).thenReturn(null); + // empty bucket + when(emptyComposite.getBuckets()) + .thenAnswer((Answer>) invocation -> { return new ArrayList(); }); + Aggregations emptyAggs = new Aggregations(Collections.singletonList(emptyComposite)); + SearchResponseSections emptySections = new SearchResponseSections(SearchHits.empty(), emptyAggs, null, false, null, null, 1); + return new SearchResponse(emptySections, null, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY, Clusters.EMPTY); + } + + private void setUpSearchResponse() throws IOException { + detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(serviceField, hostField)); + // set up a non-empty response + CompositeAggregation composite = mock(CompositeAggregation.class); + when(composite.getName()).thenReturn(CompositeRetriever.AGG_NAME_COMP); + 
when(composite.afterKey()).thenReturn(attrs3); + + String featureID = detector.getFeatureAttributes().get(0).getId(); + List compositeBuckets = new ArrayList<>(); + CompositeAggregation.Bucket bucket = mock(CompositeAggregation.Bucket.class); + when(bucket.getKey()).thenReturn(attrs1); + List aggList = new ArrayList<>(); + aggList.add(new InternalMin(featureID, randomDouble(), DocValueFormat.RAW, new HashMap<>())); + Aggregations aggregations = new Aggregations(aggList); + when(bucket.getAggregations()).thenReturn(aggregations); + compositeBuckets.add(bucket); + + bucket = mock(CompositeAggregation.Bucket.class); + when(bucket.getKey()).thenReturn(attrs2); + aggList = new ArrayList<>(); + aggList.add(new InternalMin(featureID, randomDouble(), DocValueFormat.RAW, new HashMap<>())); + aggregations = new Aggregations(aggList); + when(bucket.getAggregations()).thenReturn(aggregations); + compositeBuckets.add(bucket); + + bucket = mock(CompositeAggregation.Bucket.class); + when(bucket.getKey()).thenReturn(attrs3); + aggList = new ArrayList<>(); + aggList.add(new InternalMin(featureID, randomDouble(), DocValueFormat.RAW, new HashMap<>())); + aggregations = new Aggregations(aggList); + when(bucket.getAggregations()).thenReturn(aggregations); + compositeBuckets.add(bucket); + + when(composite.getBuckets()).thenAnswer((Answer>) invocation -> { return compositeBuckets; }); + Aggregations aggs = new Aggregations(Collections.singletonList(composite)); + + SearchResponseSections sections = new SearchResponseSections(SearchHits.empty(), aggs, null, false, null, null, 1); + SearchResponse response = new SearchResponse(sections, null, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY, Clusters.EMPTY); + + AtomicBoolean firstCalled = new AtomicBoolean(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + if (firstCalled.get()) { + listener.onResponse(createEmptyResponse()); + } else { + // set firstCalled to be true before returning in case that listener return + // and the 2nd call comes in before firstCalled is set to true. Then we + // have the 2nd response. 
+ firstCalled.set(true); + listener.onResponse(response); + } + return null; + }).when(client).search(any(), any()); + } + + private void setUpTransportInterceptor( + Function, TransportResponseHandler> interceptor, + NodeStateManager nodeStateManager + ) { + entityResultInterceptor = new TransportInterceptor() { + @Override + public AsyncSender interceptSender(AsyncSender sender) { + return new AsyncSender() { + @SuppressWarnings("unchecked") + @Override + public void sendRequest( + Transport.Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + if (action.equals(EntityResultAction.NAME)) { + sender + .sendRequest( + connection, + action, + request, + options, + interceptor.apply((TransportResponseHandler) handler) + ); + } else { + sender.sendRequest(connection, action, request, options, handler); + } + } + }; + } + }; + + // we start support multi-category fields since 1.1 + // Set version to 1.1 will force the outbound/inbound message to use 1.1 version + setupTestNodes(entityResultInterceptor, 5, settings, Version.V_2_0_0, MAX_ENTITIES_PER_QUERY, PAGE_SIZE); + + TransportService realTransportService = testNodes[0].transportService; + ClusterService realClusterService = testNodes[0].clusterService; + + action = new AnomalyResultTransportAction( + new ActionFilters(Collections.emptySet()), + realTransportService, + settings, + client, + clientUtil, + nodeStateManager, + featureQuery, + normalModelManager, + hashRing, + realClusterService, + indexNameResolver, + adCircuitBreakerService, + adStats, + threadPool, + xContentRegistry(), + adTaskManager + ); + } + + private void setUpTransportInterceptor( + Function, TransportResponseHandler> interceptor + ) { + setUpTransportInterceptor(interceptor, stateManager); + } + + public void testNonEmptyFeatures() throws InterruptedException, IOException { + setUpSearchResponse(); + setUpTransportInterceptor(this::entityResultHandler); + // mock hashing ring response. 
This has to happen after setting up test nodes with the failure interceptor + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + .thenReturn(Optional.of(testNodes[1].discoveryNode())); + setUpEntityResult(1); + + CountDownLatch modelNodeInProgress = new CountDownLatch(1); + doAnswer(invocation -> { + if (modelNodeInProgress.getCount() == 1) { + modelNodeInProgress.countDown(); + } + return null; + }).when(coldEntityQueue).putAll(any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(modelNodeInProgress.await(10000L, TimeUnit.MILLISECONDS)); + + // since we have 3 results in the first page + verify(resultWriteQueue, times(3)).put(any()); + } + + @SuppressWarnings("unchecked") + public void testCircuitBreakerOpen() throws InterruptedException, IOException { + ClientUtil clientUtil = mock(ClientUtil.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, CommonName.CONFIG_INDEX)); + return null; + }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); + + stateManager = new NodeStateManager( + client, + xContentRegistry(), + settings, + clientUtil, + clock, + AnomalyDetectorSettings.HOURLY_MAINTENANCE, + clusterService + ); + + NodeStateManager spyStateManager = spy(stateManager); + + setUpSearchResponse(); + setUpTransportInterceptor(this::entityResultHandler, spyStateManager); + // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + .thenReturn(Optional.of(testNodes[1].discoveryNode())); + + ADCircuitBreakerService openBreaker = mock(ADCircuitBreakerService.class); + when(openBreaker.isOpen()).thenReturn(true); + + // register entity result action + new EntityResultTransportAction( + new ActionFilters(Collections.emptySet()), + // since we send requests to testNodes[1] + testNodes[1].transportService, + normalModelManager, + openBreaker, + provider, + spyStateManager, + indexUtil, + resultWriteQueue, + checkpointReadQueue, + coldEntityQueue, + threadPool, + entityColdStartQueue, + adStats + ); + + CountDownLatch inProgress = new CountDownLatch(1); + doAnswer(invocation -> { + String id = invocation.getArgument(0); + Exception exp = invocation.getArgument(1); + + stateManager.setException(id, exp); + inProgress.countDown(); + return null; + }).when(spyStateManager).setException(any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(inProgress.await(10000L, TimeUnit.MILLISECONDS)); + + listener = new PlainActionFuture<>(); + action.doExecute(null, request, listener); + assertException(listener, LimitExceededException.class, CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG); + } + + public void testNotAck() throws InterruptedException, IOException { + setUpSearchResponse(); + setUpTransportInterceptor(this::unackEntityResultHandler); + // mock hashing ring response. 
This has to happen after setting up test nodes with the failure interceptor + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + .thenReturn(Optional.of(testNodes[1].discoveryNode())); + setUpEntityResult(1); + + CountDownLatch inProgress = new CountDownLatch(1); + doAnswer(invocation -> { + inProgress.countDown(); + return null; + }).when(stateManager).addPressure(anyString(), anyString()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(inProgress.await(10000L, TimeUnit.MILLISECONDS)); + + verify(stateManager, times(1)).addPressure(anyString(), anyString()); + } + + public void testMultipleNode() throws InterruptedException, IOException { + setUpSearchResponse(); + setUpTransportInterceptor(this::entityResultHandler); + + Entity entity1 = Entity.createEntityByReordering(attrs1); + Entity entity2 = Entity.createEntityByReordering(attrs2); + Entity entity3 = Entity.createEntityByReordering(attrs3); + + // we use ordered attributes values as the key to hashring + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(eq(entity1.toString()))) + .thenReturn(Optional.of(testNodes[2].discoveryNode())); + + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(eq(entity2.toString()))) + .thenReturn(Optional.of(testNodes[3].discoveryNode())); + + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(eq(entity3.toString()))) + .thenReturn(Optional.of(testNodes[4].discoveryNode())); + + for (int i = 2; i <= 4; i++) { + setUpEntityResult(i); + } + + CountDownLatch modelNodeInProgress = new CountDownLatch(3); + doAnswer(invocation -> { + modelNodeInProgress.countDown(); + return null; + }).when(coldEntityQueue).putAll(any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(modelNodeInProgress.await(10000L, TimeUnit.MILLISECONDS)); + + // since we have 3 results in the first page + verify(resultWriteQueue, times(3)).put(any()); + } + + public void testCacheSelectionError() throws IOException, InterruptedException { + setUpSearchResponse(); + setUpTransportInterceptor(this::entityResultHandler); + setUpEntityResult(1); + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + .thenReturn(Optional.of(testNodes[1].discoveryNode())); + + List hotEntities = new ArrayList<>(); + Map attrs4 = new HashMap<>(); + attrs4.put(serviceField, app0); + attrs4.put(hostField, "server_4"); + Entity entity4 = Entity.createEntityByReordering(attrs4); + hotEntities.add(entity4); + + List coldEntities = new ArrayList<>(); + Map attrs5 = new HashMap<>(); + attrs5.put(serviceField, app0); + attrs5.put(hostField, "server_5"); + Entity entity5 = Entity.createEntityByReordering(attrs5); + coldEntities.add(entity5); + + when(entityCache.selectUpdateCandidate(any(), any(), any())).thenReturn(Pair.of(hotEntities, coldEntities)); + + CountDownLatch modelNodeInProgress = new CountDownLatch(1); + doAnswer(invocation -> { + if (modelNodeInProgress.getCount() == 1) { + modelNodeInProgress.countDown(); + } + return null; + }).when(coldEntityQueue).putAll(any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, 
request, listener); + + assertTrue(modelNodeInProgress.await(10000L, TimeUnit.MILLISECONDS)); + // size 0 because cacheMissEntities has no record of these entities + verify(checkpointReadQueue).putAll(argThat(new ArgumentMatcher>() { + + @Override + public boolean matches(List argument) { + List arg = (argument); + LOG.info("size: " + arg.size()); + return arg.size() == 0; + } + })); + + verify(coldEntityQueue).putAll(argThat(new ArgumentMatcher>() { + + @Override + public boolean matches(List argument) { + List arg = (argument); + LOG.info("size: " + arg.size()); + return arg.size() == 0; + } + })); + } + + public void testCacheSelection() throws IOException, InterruptedException { + setUpSearchResponse(); + setUpTransportInterceptor(this::entityResultHandler); + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + .thenReturn(Optional.of(testNodes[1].discoveryNode())); + + List hotEntities = new ArrayList<>(); + Entity entity1 = Entity.createEntityByReordering(attrs1); + hotEntities.add(entity1); + + List coldEntities = new ArrayList<>(); + Entity entity2 = Entity.createEntityByReordering(attrs2); + coldEntities.add(entity2); + + provider = mock(CacheProvider.class); + entityCache = mock(EntityCache.class); + when(provider.get()).thenReturn(entityCache); + when(entityCache.selectUpdateCandidate(any(), any(), any())).thenReturn(Pair.of(hotEntities, coldEntities)); + when(entityCache.get(any(), any())).thenReturn(null); + + new EntityResultTransportAction( + new ActionFilters(Collections.emptySet()), + // since we send requests to testNodes[1] + testNodes[1].transportService, + normalModelManager, + adCircuitBreakerService, + provider, + stateManager, + indexUtil, + resultWriteQueue, + checkpointReadQueue, + coldEntityQueue, + threadPool, + entityColdStartQueue, + adStats + ); + + CountDownLatch modelNodeInProgress = new CountDownLatch(1); + doAnswer(invocation -> { + if (modelNodeInProgress.getCount() == 1) { + modelNodeInProgress.countDown(); + } + return null; + }).when(coldEntityQueue).putAll(any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + assertTrue(modelNodeInProgress.await(10000L, TimeUnit.MILLISECONDS)); + verify(checkpointReadQueue).putAll(argThat(new ArgumentMatcher>() { + + @Override + public boolean matches(List argument) { + List arg = (argument); + LOG.info("size: " + arg.size() + " ; element: " + arg.get(0)); + return arg.size() == 1 && arg.get(0).getEntity().equals(entity1); + } + })); + + verify(coldEntityQueue).putAll(argThat(new ArgumentMatcher>() { + + @Override + public boolean matches(List argument) { + List arg = (argument); + LOG.info("size: " + arg.size() + " ; element: " + arg.get(0)); + return arg.size() == 1 && arg.get(0).getEntity().equals(entity2); + } + })); + } + + public void testNullFeatures() throws InterruptedException { + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + CompositeAggregation emptyComposite = mock(CompositeAggregation.class); + when(emptyComposite.getName()).thenReturn(null); + when(emptyComposite.afterKey()).thenReturn(null); + // empty bucket + when(emptyComposite.getBuckets()) + .thenAnswer((Answer>) invocation -> { return new ArrayList(); }); + Aggregations emptyAggs = new Aggregations(Collections.singletonList(emptyComposite)); + SearchResponseSections emptySections = new SearchResponseSections(SearchHits.empty(), emptyAggs, null, false, null, null, 1); + SearchResponse nullResponse = new 
SearchResponse(emptySections, null, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY, Clusters.EMPTY); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(nullResponse); + inProgressLatch.countDown(); + return null; + }).when(client).search(any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(inProgressLatch.await(10000L, TimeUnit.MILLISECONDS)); + + PlainActionFuture listener2 = new PlainActionFuture<>(); + action.doExecute(null, request, listener2); + + AnomalyResultResponse response2 = listener2.actionGet(10000L); + assertEquals(Double.NaN, response2.getAnomalyGrade(), 0.01); + } + + // empty page but non-null after key will make the CompositeRetriever.PageIterator retry + public void testRetry() throws IOException, InterruptedException { + detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(serviceField, hostField)); + + // set up empty page but non-null after key + CompositeAggregation emptyNonNullComposite = mock(CompositeAggregation.class); + when(emptyNonNullComposite.getName()).thenReturn(CompositeRetriever.AGG_NAME_COMP); + when(emptyNonNullComposite.afterKey()).thenReturn(attrs3); + + List emptyNonNullCompositeBuckets = new ArrayList<>(); + when(emptyNonNullComposite.getBuckets()) + .thenAnswer((Answer>) invocation -> { return emptyNonNullCompositeBuckets; }); + + Aggregations emptyNonNullAggs = new Aggregations(Collections.singletonList(emptyNonNullComposite)); + + SearchResponseSections emptyNonNullSections = new SearchResponseSections( + SearchHits.empty(), + emptyNonNullAggs, + null, + false, + null, + null, + 1 + ); + SearchResponse emptyNonNullResponse = new SearchResponse( + emptyNonNullSections, + null, + 1, + 1, + 0, + 0, + ShardSearchFailure.EMPTY_ARRAY, + Clusters.EMPTY + ); + + // set up a non-empty response + CompositeAggregation composite = mock(CompositeAggregation.class); + when(composite.getName()).thenReturn(CompositeRetriever.AGG_NAME_COMP); + when(composite.afterKey()).thenReturn(attrs1); + + String featureID = detector.getFeatureAttributes().get(0).getId(); + List compositeBuckets = new ArrayList<>(); + CompositeAggregation.Bucket bucket = mock(CompositeAggregation.Bucket.class); + when(bucket.getKey()).thenReturn(attrs1); + List aggList = new ArrayList<>(); + aggList.add(new InternalMin(featureID, randomDouble(), DocValueFormat.RAW, new HashMap<>())); + Aggregations aggregations = new Aggregations(aggList); + when(bucket.getAggregations()).thenReturn(aggregations); + compositeBuckets.add(bucket); + + when(composite.getBuckets()).thenAnswer((Answer>) invocation -> { return compositeBuckets; }); + Aggregations aggs = new Aggregations(Collections.singletonList(composite)); + + SearchResponseSections sections = new SearchResponseSections(SearchHits.empty(), aggs, null, false, null, null, 1); + SearchResponse nonEmptyResponse = new SearchResponse(sections, null, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY, Clusters.EMPTY); + + // set up an empty response + SearchResponse emptyResponse = createEmptyResponse(); + + CountDownLatch coordinatingNodeinProgress = new CountDownLatch(3); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + if (coordinatingNodeinProgress.getCount() == 3) { + coordinatingNodeinProgress.countDown(); + 
listener.onResponse(emptyNonNullResponse); + } else if (coordinatingNodeinProgress.getCount() == 2) { + coordinatingNodeinProgress.countDown(); + listener.onResponse(nonEmptyResponse); + } else { + coordinatingNodeinProgress.countDown(); + listener.onResponse(emptyResponse); + } + return null; + }).when(client).search(any(), any()); + + // only the EntityResultRequest from nonEmptyResponse will reach model node + CountDownLatch modelNodeInProgress = new CountDownLatch(1); + doAnswer(invocation -> { + if (modelNodeInProgress.getCount() == 1) { + modelNodeInProgress.countDown(); + } + return null; + }).when(coldEntityQueue).putAll(any()); + + setUpTransportInterceptor(this::entityResultHandler); + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + .thenReturn(Optional.of(testNodes[1].discoveryNode())); + setUpEntityResult(1); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + // since coordinating node and model node run in async model (i.e., coordinating node + // does not need sync response to proceed next page, we have to make sure both + // coordinating node and model node finishes before checking assertions) + assertTrue(coordinatingNodeinProgress.await(10000L, TimeUnit.MILLISECONDS)); + assertTrue(modelNodeInProgress.await(10000L, TimeUnit.MILLISECONDS)); + + // since we have 3 results in the first page + verify(resultWriteQueue, times(1)).put(any()); + } + + public void testPageToString() { + CompositeRetriever retriever = new CompositeRetriever( + 0, + 10, + detector, + xContentRegistry(), + client, + clientUtil, + 100, + clock, + settings, + 10000, + 1000, + indexNameResolver, + clusterService + ); + Map results = new HashMap<>(); + Entity entity1 = Entity.createEntityByReordering(attrs1); + double[] val = new double[1]; + val[0] = 3.0; + results.put(entity1, val); + CompositeRetriever.Page page = retriever.new Page(results); + String repr = page.toString(); + assertTrue("actual:" + repr, repr.contains(app0)); + assertTrue("actual:" + repr, repr.contains(server1)); + } + + public void testEmptyPageToString() { + CompositeRetriever retriever = new CompositeRetriever( + 0, + 10, + detector, + xContentRegistry(), + client, + clientUtil, + 100, + clock, + settings, + 10000, + 1000, + indexNameResolver, + clusterService + ); + + CompositeRetriever.Page page = retriever.new Page(null); + String repr = page.toString(); + // we have at least class name + assertTrue("actual:" + repr, repr.contains("Page")); + } + + @SuppressWarnings("unchecked") + private NodeStateManager setUpTestExceptionTestingInModelNode() throws IOException { + setUpSearchResponse(); + setUpTransportInterceptor(this::entityResultHandler); + // mock hashing ring response. 
This has to happen after setting up test nodes with the failure interceptor + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + .thenReturn(Optional.of(testNodes[1].discoveryNode())); + + NodeStateManager modelNodeStateManager = mock(NodeStateManager.class); + CountDownLatch modelNodeInProgress = new CountDownLatch(1); + // make sure parameters are not null, otherwise this mock won't get invoked + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + modelNodeInProgress.countDown(); + return null; + }).when(modelNodeStateManager).getAnomalyDetector(anyString(), any(ActionListener.class)); + return modelNodeStateManager; + } + + public void testEndRunNowInModelNode() throws InterruptedException, IOException { + NodeStateManager modelNodeStateManager = setUpTestExceptionTestingInModelNode(); + + CountDownLatch inProgress = new CountDownLatch(1); + doAnswer(invocation -> { + inProgress.countDown(); + return Optional + .of( + new EndRunException( + detectorId, + CommonMessages.INVALID_SEARCH_QUERY_MSG, + new NoSuchElementException("No value present"), + true + ) + ); + }).when(modelNodeStateManager).fetchExceptionAndClear(anyString()); + + when(modelNodeStateManager.fetchExceptionAndClear(anyString())) + .thenReturn( + Optional + .of( + new EndRunException( + detectorId, + CommonMessages.INVALID_SEARCH_QUERY_MSG, + new NoSuchElementException("No value present"), + true + ) + ) + ); + + setUpEntityResult(1, modelNodeStateManager); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(inProgress.await(10000L, TimeUnit.MILLISECONDS)); + + // since it is end run now, we don't expect any of the normal workflow continues + verify(resultWriteQueue, never()).put(any()); + } + + public void testEndRunNowFalseInModelNode() throws InterruptedException, IOException { + NodeStateManager modelNodeStateManager = setUpTestExceptionTestingInModelNode(); + + when(modelNodeStateManager.fetchExceptionAndClear(anyString())) + .thenReturn( + Optional + .of( + new EndRunException( + detectorId, + CommonMessages.INVALID_SEARCH_QUERY_MSG, + new NoSuchElementException("No value present"), + false + ) + ) + ); + + setUpEntityResult(1, modelNodeStateManager); + + CountDownLatch inProgress = new CountDownLatch(1); + doAnswer(invocation -> { + if (inProgress.getCount() == 1) { + inProgress.countDown(); + } + return null; + }).when(stateManager).setException(anyString(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(inProgress.await(10000L, TimeUnit.MILLISECONDS)); + + // since it is end run now = false, the normal workflow continues + verify(resultWriteQueue, times(3)).put(any()); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(stateManager).setException(anyString(), exceptionCaptor.capture()); + EndRunException endRunException = (EndRunException) (exceptionCaptor.getValue()); + assertTrue(!endRunException.isEndNow()); + } + + /** + * Test that in model node, previously recorded exception is OpenSearchTimeoutException, + * @throws IOException when failing to set up transport 
layer + * @throws InterruptedException when failing to wait for inProgress to finish + */ + public void testTimeOutExceptionInModelNode() throws IOException, InterruptedException { + NodeStateManager modelNodeStateManager = setUpTestExceptionTestingInModelNode(); + + when(modelNodeStateManager.fetchExceptionAndClear(anyString())).thenReturn(Optional.of(new OpenSearchTimeoutException("blah"))); + + CountDownLatch inProgress = new CountDownLatch(1); + doAnswer(invocation -> { + inProgress.countDown(); + return null; + }).when(stateManager).setException(anyString(), any(Exception.class)); + + setUpEntityResult(1, modelNodeStateManager); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(inProgress.await(10000L, TimeUnit.MILLISECONDS)); + + // since OpenSearchTimeoutException is not end run exception (now = true), the normal workflow continues + verify(resultWriteQueue, times(3)).put(any()); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(stateManager).setException(anyString(), exceptionCaptor.capture()); + Exception actual = exceptionCaptor.getValue(); + assertTrue("actual exception is " + actual, actual instanceof InternalFailure); + } + + /** + * Test that when both the previous and current runs return exceptions, we return the more + * important exception (EndRunException is more important) + * @throws InterruptedException when failing to wait for inProgress to finish + * @throws IOException when failing to set up transport layer + */ + public void testSelectHigherExceptionInModelNode() throws InterruptedException, IOException { + when(entityCache.get(any(), any())).thenThrow(EndRunException.class); + + NodeStateManager modelNodeStateManager = setUpTestExceptionTestingInModelNode(); + + when(modelNodeStateManager.fetchExceptionAndClear(anyString())).thenReturn(Optional.of(new OpenSearchTimeoutException("blah"))); + + setUpEntityResult(1, modelNodeStateManager); + + CountDownLatch inProgress = new CountDownLatch(1); + doAnswer(invocation -> { + inProgress.countDown(); + return null; + }).when(stateManager).setException(anyString(), any(Exception.class)); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(inProgress.await(10000L, TimeUnit.MILLISECONDS)); + + // since EndRunException is thrown before getting any result, we cannot save anything + verify(resultWriteQueue, never()).put(any()); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(stateManager).setException(anyString(), exceptionCaptor.capture()); + EndRunException endRunException = (EndRunException) (exceptionCaptor.getValue()); + assertTrue(!endRunException.isEndNow()); + } + + /** + * A missing index will cause the search result to contain null aggregation + * like {"took":0,"timed_out":false,"_shards":{"total":0,"successful":0,"skipped":0,"failed":0},"hits":{"max_score":0.0,"hits":[]}} + * + * The test verifies we can handle such a situation and won't throw exceptions + * @throws InterruptedException when waiting for execution gets interrupted + */ + public void testMissingIndex() throws InterruptedException { + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + 
doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener + .onResponse( + new SearchResponse( + new SearchResponseSections(SearchHits.empty(), null, null, false, null, null, 1), + null, + 1, + 1, + 0, + 0, + ShardSearchFailure.EMPTY_ARRAY, + Clusters.EMPTY + ) + ); + inProgressLatch.countDown(); + return null; + }).when(client).search(any(), any()); + + PlainActionFuture listener = new PlainActionFuture<>(); + + action.doExecute(null, request, listener); + + AnomalyResultResponse response = listener.actionGet(10000L); + assertEquals(Double.NaN, response.getAnomalyGrade(), 0.01); + + assertTrue(inProgressLatch.await(10000L, TimeUnit.MILLISECONDS)); + verify(stateManager, times(1)).setException(eq(detectorId), any(EndRunException.class)); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorActionTests.java index a838090e1..73c67fe79 100644 --- a/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorActionTests.java @@ -18,8 +18,8 @@ import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; diff --git a/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorActionTests.java-e b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorActionTests.java-e new file mode 100644 index 000000000..73c67fe79 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorActionTests.java-e @@ -0,0 +1,74 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.time.Instant; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TestHelpers; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class PreviewAnomalyDetectorActionTests extends OpenSearchSingleNodeTestCase { + + @Override + protected NamedWriteableRegistry writableRegistry() { + return getInstanceFromNode(NamedWriteableRegistry.class); + } + + @Test + public void testPreviewRequest() throws Exception { + BytesStreamOutput out = new BytesStreamOutput(); + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + PreviewAnomalyDetectorRequest request = new PreviewAnomalyDetectorRequest( + detector, + "1234", + Instant.now().minusSeconds(60), + Instant.now() + ); + request.writeTo(out); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), writableRegistry()); + PreviewAnomalyDetectorRequest newRequest = new PreviewAnomalyDetectorRequest(input); + Assert.assertEquals(request.getId(), newRequest.getId()); + Assert.assertEquals(request.getStartTime(), newRequest.getStartTime()); + Assert.assertEquals(request.getEndTime(), newRequest.getEndTime()); + Assert.assertNotNull(newRequest.getDetector()); + Assert.assertNull(newRequest.validate()); + } + + @Test + public void testPreviewResponse() throws Exception { + BytesStreamOutput out = new BytesStreamOutput(); + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + AnomalyResult result = TestHelpers.randomHCADAnomalyDetectResult(0.8d, 0d); + PreviewAnomalyDetectorResponse response = new PreviewAnomalyDetectorResponse(ImmutableList.of(result), detector); + response.writeTo(out); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), writableRegistry()); + PreviewAnomalyDetectorResponse newResponse = new PreviewAnomalyDetectorResponse(input); + Assert.assertNotNull(newResponse.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + } + + @Test + public void testPreviewAction() throws Exception { + Assert.assertNotNull(PreviewAnomalyDetectorAction.INSTANCE.name()); + Assert.assertEquals(PreviewAnomalyDetectorAction.INSTANCE.name(), PreviewAnomalyDetectorAction.NAME); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java index f2cda2046..a60a350e6 100644 --- a/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java @@ -66,9 +66,9 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.ConfigConstants; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import 
org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.rest.RestStatus; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.threadpool.ThreadPool; diff --git a/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java-e new file mode 100644 index 000000000..bd6d01c2c --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java-e @@ -0,0 +1,414 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyObject; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.ad.AnomalyDetectorRunner; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.feature.Features; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchSingleNodeTestCase; 
+import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableMap; + +public class PreviewAnomalyDetectorTransportActionTests extends OpenSearchSingleNodeTestCase { + private ActionListener response; + private PreviewAnomalyDetectorTransportAction action; + private AnomalyDetectorRunner runner; + private ClusterService clusterService; + private FeatureManager featureManager; + private ModelManager modelManager; + private Task task; + private ADCircuitBreakerService circuitBreaker; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + task = mock(Task.class); + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + AnomalyDetectorSettings.MAX_ANOMALY_FEATURES, + AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES, + AnomalyDetectorSettings.PAGE_SIZE, + AnomalyDetectorSettings.MAX_CONCURRENT_PREVIEW + ) + ) + ) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + ClusterName clusterName = new ClusterName("test"); + Settings indexSettings = Settings + .builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .build(); + final Settings.Builder existingSettings = Settings.builder().put(indexSettings).put(IndexMetadata.SETTING_INDEX_UUID, "test2UUID"); + IndexMetadata indexMetaData = IndexMetadata.builder(CommonName.CONFIG_INDEX).settings(existingSettings).build(); + final Map indices = new HashMap<>(); + indices.put(CommonName.CONFIG_INDEX, indexMetaData); + ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().indices(indices).build()).build(); + when(clusterService.state()).thenReturn(clusterState); + + featureManager = mock(FeatureManager.class); + modelManager = mock(ModelManager.class); + runner = new AnomalyDetectorRunner(modelManager, featureManager, AnomalyDetectorSettings.MAX_PREVIEW_RESULTS); + circuitBreaker = mock(ADCircuitBreakerService.class); + when(circuitBreaker.isOpen()).thenReturn(false); + action = new PreviewAnomalyDetectorTransportAction( + Settings.EMPTY, + mock(TransportService.class), + clusterService, + mock(ActionFilters.class), + client(), + runner, + xContentRegistry(), + circuitBreaker + ); + } + + @SuppressWarnings("unchecked") + @Test + public void testPreviewTransportAction() throws IOException, InterruptedException { + final CountDownLatch inProgressLatch = new CountDownLatch(1); + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + PreviewAnomalyDetectorRequest request = new PreviewAnomalyDetectorRequest(detector, detector.getId(), Instant.now(), Instant.now()); + ActionListener previewResponse = new ActionListener() { + @Override + public void onResponse(PreviewAnomalyDetectorResponse response) { + try { + XContentBuilder previewBuilder = response.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS); + Assert.assertNotNull(previewBuilder); + Map map = TestHelpers.XContentBuilderToMap(previewBuilder); + List results = (List) map.get("anomaly_result"); + 
Assert.assertNotNull(results); + Assert.assertTrue(results.size() > 0); + inProgressLatch.countDown(); + } catch (IOException e) { + // Should not reach here + Assert.assertTrue(false); + } + } + + @Override + public void onFailure(Exception e) { + // onFailure should not be called + Assert.assertTrue(false); + } + }; + + doReturn(TestHelpers.randomThresholdingResults()).when(modelManager).getPreviewResults(any(), anyInt()); + + doAnswer(responseMock -> { + Long startTime = responseMock.getArgument(1); + ActionListener listener = responseMock.getArgument(3); + listener.onResponse(TestHelpers.randomFeatures()); + return null; + }).when(featureManager).getPreviewFeatures(anyObject(), anyLong(), anyLong(), any()); + action.doExecute(task, request, previewResponse); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + @Test + public void testPreviewTransportActionWithNoFeature() throws IOException, InterruptedException { + // Detector with no feature, Preview should fail + final CountDownLatch inProgressLatch = new CountDownLatch(1); + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(Collections.emptyList()); + PreviewAnomalyDetectorRequest request = new PreviewAnomalyDetectorRequest(detector, detector.getId(), Instant.now(), Instant.now()); + ActionListener previewResponse = new ActionListener() { + @Override + public void onResponse(PreviewAnomalyDetectorResponse response) { + Assert.assertTrue(false); + } + + @Override + public void onFailure(Exception e) { + Assert.assertTrue(e.getMessage().contains("Can't preview detector without feature")); + inProgressLatch.countDown(); + } + }; + action.doExecute(task, request, previewResponse); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + @Test + public void testPreviewTransportActionWithNoDetector() throws IOException, InterruptedException { + // When detectorId is null, preview should fail + final CountDownLatch inProgressLatch = new CountDownLatch(1); + PreviewAnomalyDetectorRequest request = new PreviewAnomalyDetectorRequest(null, "123", Instant.now(), Instant.now()); + ActionListener previewResponse = new ActionListener() { + @Override + public void onResponse(PreviewAnomalyDetectorResponse response) { + Assert.assertTrue(false); + } + + @Override + public void onFailure(Exception e) { + Assert.assertTrue(e.getMessage().contains("Could not execute get query to find detector")); + inProgressLatch.countDown(); + } + }; + action.doExecute(task, request, previewResponse); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + @Test + public void testPreviewTransportActionWithDetectorID() throws IOException, InterruptedException { + // When AD index does not exist, cannot query the detector + final CountDownLatch inProgressLatch = new CountDownLatch(1); + PreviewAnomalyDetectorRequest request = new PreviewAnomalyDetectorRequest(null, "1234", Instant.now(), Instant.now()); + ActionListener previewResponse = new ActionListener() { + @Override + public void onResponse(PreviewAnomalyDetectorResponse response) { + Assert.assertTrue(false); + } + + @Override + public void onFailure(Exception e) { + Assert.assertTrue(e.getMessage().contains("Could not execute get query to find detector")); + inProgressLatch.countDown(); + } + }; + action.doExecute(task, request, previewResponse); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + @Test + public void testPreviewTransportActionWithIndex() throws IOException, InterruptedException { + // When AD index exists, and detector does not 
exist + final CountDownLatch inProgressLatch = new CountDownLatch(1); + PreviewAnomalyDetectorRequest request = new PreviewAnomalyDetectorRequest(null, "1234", Instant.now(), Instant.now()); + Settings indexSettings = Settings.builder().put("index.number_of_shards", 5).put("index.number_of_replicas", 1).build(); + CreateIndexRequest indexRequest = new CreateIndexRequest(CommonName.CONFIG_INDEX, indexSettings); + client().admin().indices().create(indexRequest).actionGet(); + ActionListener previewResponse = new ActionListener() { + @Override + public void onResponse(PreviewAnomalyDetectorResponse response) { + Assert.assertTrue(false); + } + + @Override + public void onFailure(Exception e) { + Assert.assertTrue(e.getMessage().contains("Can't find anomaly detector with id:1234")); + inProgressLatch.countDown(); + } + }; + action.doExecute(task, request, previewResponse); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + @SuppressWarnings("unchecked") + @Test + public void testPreviewTransportActionNoContext() throws IOException, InterruptedException { + final CountDownLatch inProgressLatch = new CountDownLatch(1); + Settings settings = Settings.builder().put(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.getKey(), true).build(); + Client client = mock(Client.class); + ThreadContext threadContext = new ThreadContext(settings); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alice|odfe,aes|engineering,operations"); + org.opensearch.threadpool.ThreadPool mockThreadPool = mock(ThreadPool.class); + when(client.threadPool()).thenReturn(mockThreadPool); + when(mockThreadPool.getThreadContext()).thenReturn(threadContext); + PreviewAnomalyDetectorTransportAction previewAction = new PreviewAnomalyDetectorTransportAction( + settings, + mock(TransportService.class), + clusterService, + mock(ActionFilters.class), + client, + runner, + xContentRegistry(), + circuitBreaker + ); + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + PreviewAnomalyDetectorRequest request = new PreviewAnomalyDetectorRequest(detector, detector.getId(), Instant.now(), Instant.now()); + + GetResponse getDetectorResponse = TestHelpers.createGetResponse(detector, detector.getId(), CommonName.CONFIG_INDEX); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. 
Its content is %s", args.length, Arrays.toString(args)), + args.length == 2 + ); + + assertTrue(args[0] instanceof GetRequest); + assertTrue(args[1] instanceof ActionListener); + + ActionListener listener = (ActionListener) args[1]; + listener.onResponse(getDetectorResponse); + return null; + }).when(client).get(any(GetRequest.class), any()); + + ActionListener responseActionListener = new ActionListener() { + @Override + public void onResponse(PreviewAnomalyDetectorResponse response) { + Assert.assertTrue(false); + } + + @Override + public void onFailure(Exception e) { + Assert.assertEquals(OpenSearchStatusException.class, e.getClass()); + inProgressLatch.countDown(); + } + }; + previewAction.doExecute(task, request, responseActionListener); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + @SuppressWarnings("unchecked") + @Test + public void testPreviewTransportActionWithDetector() throws IOException, InterruptedException { + final CountDownLatch inProgressLatch = new CountDownLatch(1); + CreateIndexResponse createResponse = TestHelpers + .createIndex(client().admin(), CommonName.CONFIG_INDEX, ADIndexManagement.getConfigMappings()); + Assert.assertNotNull(createResponse); + + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + IndexRequest indexRequest = new IndexRequest(CommonName.CONFIG_INDEX) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .source(detector.toXContent(XContentFactory.jsonBuilder(), RestHandlerUtils.XCONTENT_WITH_TYPE)); + IndexResponse indexResponse = client().index(indexRequest).actionGet(5_000); + assertEquals(RestStatus.CREATED, indexResponse.status()); + + PreviewAnomalyDetectorRequest request = new PreviewAnomalyDetectorRequest( + null, + indexResponse.getId(), + Instant.now(), + Instant.now() + ); + ActionListener previewResponse = new ActionListener() { + @Override + public void onResponse(PreviewAnomalyDetectorResponse response) { + try { + XContentBuilder previewBuilder = response.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS); + Assert.assertNotNull(previewBuilder); + Map map = TestHelpers.XContentBuilderToMap(previewBuilder); + List results = (List) map.get("anomaly_result"); + Assert.assertNotNull(results); + Assert.assertTrue(results.size() > 0); + inProgressLatch.countDown(); + } catch (IOException e) { + // Should not reach here + Assert.assertTrue(false); + } + } + + @Override + public void onFailure(Exception e) { + // onFailure should not be called + Assert.assertTrue(false); + } + }; + doReturn(TestHelpers.randomThresholdingResults()).when(modelManager).getPreviewResults(any(), anyInt()); + + doAnswer(responseMock -> { + Long startTime = responseMock.getArgument(1); + ActionListener listener = responseMock.getArgument(3); + listener.onResponse(TestHelpers.randomFeatures()); + return null; + }).when(featureManager).getPreviewFeatures(anyObject(), anyLong(), anyLong(), any()); + action.doExecute(task, request, previewResponse); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + @Test + public void testCircuitBreakerOpen() throws IOException, InterruptedException { + // preview has no detector id + AnomalyDetector detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(null, Arrays.asList("a")); + PreviewAnomalyDetectorRequest request = new PreviewAnomalyDetectorRequest(detector, detector.getId(), Instant.now(), Instant.now()); + + when(circuitBreaker.isOpen()).thenReturn(true); + + final CountDownLatch inProgressLatch = 
new CountDownLatch(1); + ActionListener previewResponse = new ActionListener() { + @Override + public void onResponse(PreviewAnomalyDetectorResponse response) { + Assert.assertTrue(false); + } + + @Override + public void onFailure(Exception e) { + Assert.assertTrue("actual class: " + e.getClass(), e instanceof OpenSearchStatusException); + Assert.assertTrue(e.getMessage().contains(CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG)); + inProgressLatch.countDown(); + } + }; + action.doExecute(task, request, previewResponse); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ProfileITTests.java b/src/test/java/org/opensearch/ad/transport/ProfileITTests.java index e9aac6377..013f00097 100644 --- a/src/test/java/org/opensearch/ad/transport/ProfileITTests.java +++ b/src/test/java/org/opensearch/ad/transport/ProfileITTests.java @@ -16,20 +16,20 @@ import java.util.HashSet; import java.util.concurrent.ExecutionException; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.plugins.Plugin; import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; public class ProfileITTests extends OpenSearchIntegTestCase { @Override protected Collection> nodePlugins() { - return Collections.singletonList(AnomalyDetectorPlugin.class); + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } protected Collection> transportClientPlugins() { - return Collections.singletonList(AnomalyDetectorPlugin.class); + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } public void testNormalProfile() throws ExecutionException, InterruptedException { diff --git a/src/test/java/org/opensearch/ad/transport/ProfileITTests.java-e b/src/test/java/org/opensearch/ad/transport/ProfileITTests.java-e new file mode 100644 index 000000000..013f00097 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ProfileITTests.java-e @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.concurrent.ExecutionException; + +import org.opensearch.ad.model.DetectorProfileName; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +public class ProfileITTests extends OpenSearchIntegTestCase { + + @Override + protected Collection> nodePlugins() { + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); + } + + protected Collection> transportClientPlugins() { + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); + } + + public void testNormalProfile() throws ExecutionException, InterruptedException { + ProfileRequest profileRequest = new ProfileRequest("123", new HashSet(), false); + + ProfileResponse response = client().execute(ProfileAction.INSTANCE, profileRequest).get(); + assertTrue("getting profile failed", !response.hasFailures()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ProfileTests.java b/src/test/java/org/opensearch/ad/transport/ProfileTests.java index affa5d8a1..ff33ed277 100644 --- a/src/test/java/org/opensearch/ad/transport/ProfileTests.java +++ b/src/test/java/org/opensearch/ad/transport/ProfileTests.java @@ -37,8 +37,8 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.transport.TransportAddress; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.test.OpenSearchTestCase; diff --git a/src/test/java/org/opensearch/ad/transport/ProfileTests.java-e b/src/test/java/org/opensearch/ad/transport/ProfileTests.java-e new file mode 100644 index 000000000..a7918f075 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ProfileTests.java-e @@ -0,0 +1,272 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.Version; +import org.opensearch.action.FailedNodeException; +import org.opensearch.ad.common.exception.JsonPathNotFoundException; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.DetectorProfileName; +import org.opensearch.ad.model.ModelProfileOnNode; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.constant.CommonName; + +import test.org.opensearch.ad.util.JsonDeserializer; + +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; + +public class ProfileTests extends OpenSearchTestCase { + String node1, nodeName1, clusterName; + String node2, nodeName2; + Map clusterStats; + DiscoveryNode discoveryNode1, discoveryNode2; + long modelSize; + String model1Id; + String model0Id; + String detectorId; + int shingleSize; + Map modelSizeMap1, modelSizeMap2; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + clusterName = "test-cluster-name"; + + node1 = "node1"; + nodeName1 = "nodename1"; + discoveryNode1 = new DiscoveryNode( + nodeName1, + node1, + new TransportAddress(TransportAddress.META_ADDRESS, 9300), + emptyMap(), + emptySet(), + Version.CURRENT + ); + + node2 = "node2"; + nodeName2 = "nodename2"; + discoveryNode2 = new DiscoveryNode( + nodeName2, + node2, + new TransportAddress(TransportAddress.META_ADDRESS, 9301), + emptyMap(), + emptySet(), + Version.CURRENT + ); + + clusterStats = new HashMap<>(); + + modelSize = 4456448L; + model1Id = "Pl536HEBnXkDrah03glg_model_rcf_1"; + model0Id = "Pl536HEBnXkDrah03glg_model_rcf_0"; + detectorId = "123"; + shingleSize = 6; + + modelSizeMap1 = new HashMap() { + { + put(model1Id, modelSize); + } + }; + + modelSizeMap2 = new HashMap() { + { + put(model0Id, modelSize); + } + }; + } + + @Test + public void testProfileNodeRequest() throws IOException { + + Set profilesToRetrieve = new HashSet(); + profilesToRetrieve.add(DetectorProfileName.COORDINATING_NODE); + ProfileRequest ProfileRequest = new ProfileRequest(detectorId, profilesToRetrieve, false); + ProfileNodeRequest ProfileNodeRequest = new ProfileNodeRequest(ProfileRequest); + assertEquals("ProfileNodeRequest has the wrong detector id", ProfileNodeRequest.getId(), detectorId); + assertEquals("ProfileNodeRequest has the wrong ProfileRequest", ProfileNodeRequest.getProfilesToBeRetrieved(), profilesToRetrieve); + + // Test serialization + BytesStreamOutput output = new BytesStreamOutput(); + ProfileNodeRequest.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + ProfileNodeRequest nodeRequest = new ProfileNodeRequest(streamInput); + 
assertEquals("serialization has the wrong detector id", nodeRequest.getId(), detectorId); + assertEquals("serialization has the wrong ProfileRequest", nodeRequest.getProfilesToBeRetrieved(), profilesToRetrieve); + + } + + @Test + public void testProfileNodeResponse() throws IOException, JsonPathNotFoundException { + + // Test serialization + ProfileNodeResponse profileNodeResponse = new ProfileNodeResponse( + discoveryNode1, + modelSizeMap1, + shingleSize, + 0, + 0, + new ArrayList<>(), + modelSizeMap1.size() + ); + BytesStreamOutput output = new BytesStreamOutput(); + profileNodeResponse.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + ProfileNodeResponse readResponse = ProfileNodeResponse.readProfiles(streamInput); + assertEquals("serialization has the wrong model size", readResponse.getModelSize(), profileNodeResponse.getModelSize()); + assertEquals("serialization has the wrong shingle size", readResponse.getShingleSize(), profileNodeResponse.getShingleSize()); + + // Test toXContent + XContentBuilder builder = jsonBuilder(); + profileNodeResponse.toXContent(builder.startObject(), ToXContent.EMPTY_PARAMS).endObject(); + String json = Strings.toString(builder); + + for (Map.Entry profile : modelSizeMap1.entrySet()) { + assertEquals( + "toXContent has the wrong model size", + JsonDeserializer.getLongValue(json, CommonName.MODEL_SIZE_IN_BYTES, profile.getKey()), + profile.getValue().longValue() + ); + } + + assertEquals("toXContent has the wrong shingle size", JsonDeserializer.getIntValue(json, ADCommonName.SHINGLE_SIZE), shingleSize); + } + + @Test + public void testProfileRequest() throws IOException { + String detectorId = "123"; + Set profilesToRetrieve = new HashSet(); + profilesToRetrieve.add(DetectorProfileName.COORDINATING_NODE); + ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, false); + + // Test Serialization + BytesStreamOutput output = new BytesStreamOutput(); + profileRequest.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + ProfileRequest readRequest = new ProfileRequest(streamInput); + assertEquals( + "Serialization has the wrong profiles to be retrieved", + readRequest.getProfilesToBeRetrieved(), + profileRequest.getProfilesToBeRetrieved() + ); + assertEquals("Serialization has the wrong detector id", readRequest.getId(), profileRequest.getId()); + } + + @Test + public void testProfileResponse() throws IOException, JsonPathNotFoundException { + + ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse( + discoveryNode1, + modelSizeMap1, + shingleSize, + 0, + 0, + new ArrayList<>(), + modelSizeMap1.size() + ); + ProfileNodeResponse profileNodeResponse2 = new ProfileNodeResponse( + discoveryNode2, + modelSizeMap2, + -1, + 0, + 0, + new ArrayList<>(), + modelSizeMap2.size() + ); + List profileNodeResponses = Arrays.asList(profileNodeResponse1, profileNodeResponse2); + List failures = Collections.emptyList(); + ProfileResponse profileResponse = new ProfileResponse(new ClusterName(clusterName), profileNodeResponses, failures); + + assertEquals(node1, profileResponse.getCoordinatingNode()); + assertEquals(shingleSize, profileResponse.getShingleSize()); + assertEquals(modelSize * 2, profileResponse.getTotalSizeInBytes()); + assertEquals(2, profileResponse.getModelProfile().length); + for (ModelProfileOnNode profile : profileResponse.getModelProfile()) { + assertTrue(node1.equals(profile.getNodeId()) || node2.equals(profile.getNodeId())); + assertEquals(modelSize, 
profile.getModelSize()); + if (node1.equals(profile.getNodeId())) { + assertEquals(model1Id, profile.getModelId()); + } + if (node2.equals(profile.getNodeId())) { + assertEquals(model0Id, profile.getModelId()); + } + } + + // Test toXContent + XContentBuilder builder = jsonBuilder(); + profileResponse.toXContent(builder.startObject(), ToXContent.EMPTY_PARAMS).endObject(); + String json = Strings.toString(builder); + + logger.info("JSON: " + json); + + assertEquals( + "toXContent has the wrong coordinating node", + node1, + JsonDeserializer.getTextValue(json, ProfileResponse.COORDINATING_NODE) + ); + assertEquals( + "toXContent has the wrong shingle size", + shingleSize, + JsonDeserializer.getLongValue(json, ProfileResponse.SHINGLE_SIZE) + ); + assertEquals("toXContent has the wrong total size", modelSize * 2, JsonDeserializer.getLongValue(json, ProfileResponse.TOTAL_SIZE)); + + JsonArray modelsJson = JsonDeserializer.getArrayValue(json, ProfileResponse.MODELS); + + for (int i = 0; i < modelsJson.size(); i++) { + JsonElement element = modelsJson.get(i); + assertTrue( + "toXContent has the wrong model id", + JsonDeserializer.getTextValue(element, CommonName.MODEL_ID_FIELD).equals(model1Id) + || JsonDeserializer.getTextValue(element, CommonName.MODEL_ID_FIELD).equals(model0Id) + ); + + assertEquals( + "toXContent has the wrong model size", + JsonDeserializer.getLongValue(element, CommonName.MODEL_SIZE_IN_BYTES), + modelSize + ); + + if (JsonDeserializer.getTextValue(element, CommonName.MODEL_ID_FIELD).equals(model1Id)) { + assertEquals("toXContent has the wrong node id", JsonDeserializer.getTextValue(element, ModelProfileOnNode.NODE_ID), node1); + } else { + assertEquals("toXContent has the wrong node id", JsonDeserializer.getTextValue(element, ModelProfileOnNode.NODE_ID), node2); + } + + } + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java index e30a188fb..f522d89f1 100644 --- a/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java @@ -29,7 +29,6 @@ import org.junit.Test; import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.caching.CacheProvider; import org.opensearch.ad.caching.EntityCache; import org.opensearch.ad.feature.FeatureManager; @@ -41,6 +40,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.plugins.Plugin; import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.model.Entity; import org.opensearch.transport.TransportService; @@ -120,7 +120,7 @@ private void setUpModelSize(int maxModel) { @Override protected Collection> nodePlugins() { - return Arrays.asList(AnomalyDetectorPlugin.class); + return Arrays.asList(TimeSeriesAnalyticsPlugin.class); } @Test diff --git a/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java-e new file mode 100644 index 000000000..f522d89f1 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java-e @@ -0,0 +1,232 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 
license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.caching.CacheProvider; +import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.model.DetectorProfileName; +import org.opensearch.ad.model.ModelProfile; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.settings.Settings; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.transport.TransportService; + +public class ProfileTransportActionTests extends OpenSearchIntegTestCase { + private ProfileTransportAction action; + private String detectorId = "Pl536HEBnXkDrah03glg"; + String node1, nodeName1; + DiscoveryNode discoveryNode1; + Set profilesToRetrieve = new HashSet(); + private int shingleSize = 6; + private long modelSize = 4456448L; + private String modelId = "Pl536HEBnXkDrah03glg_model_rcf_1"; + private CacheProvider cacheProvider; + private int activeEntities = 10; + private long totalUpdates = 127; + private long multiEntityModelSize = 712480L; + private ModelManager modelManager; + private FeatureManager featureManager; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + modelManager = mock(ModelManager.class); + featureManager = mock(FeatureManager.class); + + when(featureManager.getShingleSize(any(String.class))).thenReturn(shingleSize); + + EntityCache cache = mock(EntityCache.class); + cacheProvider = mock(CacheProvider.class); + when(cacheProvider.get()).thenReturn(cache); + when(cache.getActiveEntities(anyString())).thenReturn(activeEntities); + when(cache.getTotalUpdates(anyString())).thenReturn(totalUpdates); + Map multiEntityModelSizeMap = new HashMap<>(); + String modelId1 = "T4c3dXUBj-2IZN7itix__entity_app_3"; + String modelId2 = "T4c3dXUBj-2IZN7itix__entity_app_2"; + multiEntityModelSizeMap.put(modelId1, multiEntityModelSize); + multiEntityModelSizeMap.put(modelId2, multiEntityModelSize); + when(cache.getModelSize(anyString())).thenReturn(multiEntityModelSizeMap); + + List modelProfiles = new ArrayList<>(); + String field = "field"; + String fieldVal1 = "value1"; + String fieldVal2 = "value2"; + Entity entity1 = Entity.createSingleAttributeEntity(field, fieldVal1); + Entity entity2 = Entity.createSingleAttributeEntity(field, fieldVal2); + modelProfiles.add(new ModelProfile(modelId1, entity1, multiEntityModelSize)); + modelProfiles.add(new ModelProfile(modelId1, entity2, multiEntityModelSize)); + when(cache.getAllModelProfile(anyString())).thenReturn(modelProfiles); + + Map modelSizes = new HashMap<>(); + 
modelSizes.put(modelId, modelSize); + when(modelManager.getModelSize(any(String.class))).thenReturn(modelSizes); + + Settings settings = Settings.builder().put("plugins.anomaly_detection.max_model_size_per_node", 100).build(); + + action = new ProfileTransportAction( + client().threadPool(), + clusterService(), + mock(TransportService.class), + mock(ActionFilters.class), + modelManager, + featureManager, + cacheProvider, + settings + ); + + profilesToRetrieve = new HashSet(); + profilesToRetrieve.add(DetectorProfileName.COORDINATING_NODE); + } + + private void setUpModelSize(int maxModel) { + Settings nodeSettings = Settings.builder().put(AnomalyDetectorSettings.MAX_MODEL_SIZE_PER_NODE.getKey(), maxModel).build(); + internalCluster().startNode(nodeSettings); + } + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(TimeSeriesAnalyticsPlugin.class); + } + + @Test + public void testNewResponse() { + setUpModelSize(100); + DiscoveryNode node = clusterService().localNode(); + ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, false, node); + + ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse(node, new HashMap<>(), shingleSize, 0, 0, new ArrayList<>(), 0); + List profileNodeResponses = Arrays.asList(profileNodeResponse1); + List failures = new ArrayList<>(); + + ProfileResponse profileResponse = action.newResponse(profileRequest, profileNodeResponses, failures); + assertEquals(node.getId(), profileResponse.getCoordinatingNode()); + } + + @Test + public void testNewNodeRequest() { + setUpModelSize(100); + ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, false); + + ProfileNodeRequest profileNodeRequest1 = new ProfileNodeRequest(profileRequest); + ProfileNodeRequest profileNodeRequest2 = action.newNodeRequest(profileRequest); + + assertEquals(profileNodeRequest1.getId(), profileNodeRequest2.getId()); + assertEquals(profileNodeRequest2.getProfilesToBeRetrieved(), profileNodeRequest2.getProfilesToBeRetrieved()); + } + + @Test + public void testNodeOperation() { + setUpModelSize(100); + DiscoveryNode nodeId = clusterService().localNode(); + ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, false, nodeId); + + ProfileNodeResponse response = action.nodeOperation(new ProfileNodeRequest(profileRequest)); + + assertEquals(shingleSize, response.getShingleSize()); + assertEquals(null, response.getModelSize()); + + profilesToRetrieve = new HashSet(); + profilesToRetrieve.add(DetectorProfileName.TOTAL_SIZE_IN_BYTES); + + profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, false, nodeId); + response = action.nodeOperation(new ProfileNodeRequest(profileRequest)); + + assertEquals(-1, response.getShingleSize()); + assertEquals(1, response.getModelSize().size()); + assertEquals(modelSize, response.getModelSize().get(modelId).longValue()); + } + + @Test + public void testMultiEntityNodeOperation() { + setUpModelSize(100); + DiscoveryNode nodeId = clusterService().localNode(); + profilesToRetrieve = new HashSet(); + profilesToRetrieve.add(DetectorProfileName.ACTIVE_ENTITIES); + ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, true, nodeId); + + ProfileNodeResponse response = action.nodeOperation(new ProfileNodeRequest(profileRequest)); + + assertEquals(activeEntities, response.getActiveEntities()); + assertEquals(null, response.getModelSize()); + + profilesToRetrieve.add(DetectorProfileName.INIT_PROGRESS); + + profileRequest = 
new ProfileRequest(detectorId, profilesToRetrieve, true, nodeId); + response = action.nodeOperation(new ProfileNodeRequest(profileRequest)); + + assertEquals(activeEntities, response.getActiveEntities()); + assertEquals(null, response.getModelSize()); + assertEquals(totalUpdates, response.getTotalUpdates()); + + profilesToRetrieve.add(DetectorProfileName.MODELS); + profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, true, nodeId); + response = action.nodeOperation(new ProfileNodeRequest(profileRequest)); + + assertEquals(activeEntities, response.getActiveEntities()); + assertEquals(null, response.getModelSize()); + assertEquals(2, response.getModelProfiles().size()); + assertEquals(totalUpdates, response.getTotalUpdates()); + assertEquals(2, response.getModelCount()); + } + + @Test + public void testModelCount() { + setUpModelSize(1); + + Settings settings = Settings.builder().put("plugins.anomaly_detection.max_model_size_per_node", 1).build(); + + action = new ProfileTransportAction( + client().threadPool(), + clusterService(), + mock(TransportService.class), + mock(ActionFilters.class), + modelManager, + featureManager, + cacheProvider, + settings + ); + + DiscoveryNode nodeId = clusterService().localNode(); + profilesToRetrieve = new HashSet(); + profilesToRetrieve.add(DetectorProfileName.MODELS); + ProfileRequest profileRequest = new ProfileRequest(detectorId, profilesToRetrieve, true, nodeId); + ProfileNodeResponse response = action.nodeOperation(new ProfileNodeRequest(profileRequest)); + assertEquals(2, response.getModelCount()); + assertEquals(1, response.getModelProfiles().size()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java b/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java index cf6e45e14..edb480dd1 100644 --- a/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java +++ b/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java @@ -36,9 +36,9 @@ import org.opensearch.ad.ml.SingleStreamModelIdMapper; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.tasks.Task; import org.opensearch.timeseries.AbstractTimeSeriesTest; diff --git a/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java-e b/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java-e new file mode 100644 index 000000000..5626b608b --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java-e @@ -0,0 +1,365 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.Optional; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.common.exception.JsonPathNotFoundException; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.SingleStreamModelIdMapper; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.transport.ConnectTransportException; +import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportInterceptor; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.TransportResponse; +import org.opensearch.transport.TransportResponseHandler; +import org.opensearch.transport.TransportService; + +import test.org.opensearch.ad.util.FakeNode; +import test.org.opensearch.ad.util.JsonDeserializer; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; + +public class RCFPollingTests extends AbstractTimeSeriesTest { + Gson gson = new GsonBuilder().create(); + private String detectorId = "jqIG6XIBEyaF3zCMZfcB"; + private String model0Id; + private long totalUpdates = 3L; + private String nodeId = "abc"; + private ClusterService clusterService; + private HashRing hashRing; + private TransportAddress transportAddress1; + private ModelManager manager; + private TransportService transportService; + private PlainActionFuture future; + private RCFPollingTransportAction action; + private RCFPollingRequest request; + private TransportInterceptor normalTransportInterceptor, failureTransportInterceptor; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(RCFPollingTests.class.getSimpleName()); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + } + + private void registerHandler(FakeNode node) { + new RCFPollingTransportAction( + new ActionFilters(Collections.emptySet()), + node.transportService, + Settings.EMPTY, + manager, + hashRing, + node.clusterService + ); + } + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + clusterService = mock(ClusterService.class); + hashRing = mock(HashRing.class); + transportAddress1 = new TransportAddress(new InetSocketAddress(InetAddress.getByName("1.2.3.4"), 9300)); + manager = mock(ModelManager.class); + transportService = new TransportService( + 
Settings.EMPTY, + mock(Transport.class), + null, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + x -> null, + null, + Collections.emptySet() + ); + future = new PlainActionFuture<>(); + + request = new RCFPollingRequest(detectorId); + model0Id = SingleStreamModelIdMapper.getRcfModelId(detectorId, 0); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) args[2]; + listener.onResponse(totalUpdates); + return null; + }).when(manager).getTotalUpdates(any(String.class), any(String.class), any()); + + normalTransportInterceptor = new TransportInterceptor() { + @Override + public AsyncSender interceptSender(AsyncSender sender) { + return new AsyncSender() { + @Override + public void sendRequest( + Transport.Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + if (RCFPollingAction.NAME.equals(action)) { + sender.sendRequest(connection, action, request, options, rcfRollingHandler(handler)); + } else { + sender.sendRequest(connection, action, request, options, handler); + } + } + }; + } + }; + + failureTransportInterceptor = new TransportInterceptor() { + @Override + public AsyncSender interceptSender(AsyncSender sender) { + return new AsyncSender() { + @Override + public void sendRequest( + Transport.Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + if (RCFPollingAction.NAME.equals(action)) { + sender.sendRequest(connection, action, request, options, rcfFailureRollingHandler(handler)); + } else { + sender.sendRequest(connection, action, request, options, handler); + } + } + }; + } + }; + } + + public void testDoubleNaN() { + try { + gson.toJson(Double.NaN); + } catch (Exception e) { + assertTrue(e instanceof IllegalArgumentException); + assertTrue(e.getMessage().contains("NaN is not a valid double value as per JSON specification")); + } + + Gson gson = new GsonBuilder().serializeSpecialFloatingPointValues().create(); + String json = gson.toJson(Double.NaN); + assertEquals("NaN", json); + Double value = gson.fromJson(json, Double.class); + assertTrue(value.isNaN()); + } + + public void testNormal() { + DiscoveryNode localNode = new DiscoveryNode(nodeId, transportAddress1, Version.CURRENT.minimumCompatibilityVersion()); + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(Optional.of(localNode)); + + when(clusterService.localNode()).thenReturn(localNode); + + action = new RCFPollingTransportAction( + mock(ActionFilters.class), + transportService, + Settings.EMPTY, + manager, + hashRing, + clusterService + ); + action.doExecute(mock(Task.class), request, future); + + RCFPollingResponse response = future.actionGet(); + assertEquals(totalUpdates, response.getTotalUpdates()); + } + + public void testNoNodeFoundForModel() { + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(Optional.empty()); + action = new RCFPollingTransportAction( + mock(ActionFilters.class), + transportService, + Settings.EMPTY, + manager, + hashRing, + clusterService + ); + action.doExecute(mock(Task.class), request, future); + assertException(future, TimeSeriesException.class, RCFPollingTransportAction.NO_NODE_FOUND_MSG); + } + + /** + * Precondition: receiver's model manager respond with a response. See + * manager.getRcfModelId mocked output in setUp method. 
+ * When receiving a response, respond back with totalUpdates. + * @param handler handler for receiver + * @return handler for request sender + */ + private TransportResponseHandler rcfRollingHandler(TransportResponseHandler handler) { + return new TransportResponseHandler() { + @Override + public T read(StreamInput in) throws IOException { + return handler.read(in); + } + + @Override + @SuppressWarnings("unchecked") + public void handleResponse(T response) { + handler.handleResponse((T) new RCFPollingResponse(totalUpdates)); + } + + @Override + public void handleException(TransportException exp) { + handler.handleException(exp); + } + + @Override + public String executor() { + return handler.executor(); + } + }; + } + + /** + * Precondition: receiver's model manager respond with a response. See + * manager.getRcfModelId mocked output in setUp method. + * Create handler that would return a connection failure + * @param handler callback handler + * @return handlder that would return a connection failure + */ + private TransportResponseHandler rcfFailureRollingHandler(TransportResponseHandler handler) { + return new TransportResponseHandler() { + @Override + public T read(StreamInput in) throws IOException { + return handler.read(in); + } + + @Override + public void handleResponse(T response) { + handler + .handleException( + new ConnectTransportException( + new DiscoveryNode(nodeId, transportAddress1, Version.CURRENT.minimumCompatibilityVersion()), + RCFPollingAction.NAME + ) + ); + } + + @Override + public void handleException(TransportException exp) { + handler.handleException(exp); + } + + @Override + public String executor() { + return handler.executor(); + } + }; + } + + public void testGetRemoteNormalResponse() { + setupTestNodes(normalTransportInterceptor, Settings.EMPTY); + try { + TransportService realTransportService = testNodes[0].transportService; + clusterService = testNodes[0].clusterService; + + action = new RCFPollingTransportAction( + new ActionFilters(Collections.emptySet()), + realTransportService, + Settings.EMPTY, + manager, + hashRing, + clusterService + ); + + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + .thenReturn(Optional.of(testNodes[1].discoveryNode())); + registerHandler(testNodes[1]); + + action.doExecute(null, request, future); + + RCFPollingResponse response = future.actionGet(); + assertEquals(totalUpdates, response.getTotalUpdates()); + } finally { + tearDownTestNodes(); + } + } + + public void testGetRemoteFailureResponse() { + setupTestNodes(failureTransportInterceptor, Settings.EMPTY); + try { + TransportService realTransportService = testNodes[0].transportService; + clusterService = testNodes[0].clusterService; + + action = new RCFPollingTransportAction( + new ActionFilters(Collections.emptySet()), + realTransportService, + Settings.EMPTY, + manager, + hashRing, + clusterService + ); + + when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + .thenReturn(Optional.of(testNodes[1].discoveryNode())); + registerHandler(testNodes[1]); + + action.doExecute(null, request, future); + + expectThrows(ConnectTransportException.class, () -> future.actionGet()); + } finally { + tearDownTestNodes(); + } + } + + public void testResponseToXContent() throws IOException, JsonPathNotFoundException { + RCFPollingResponse response = new RCFPollingResponse(totalUpdates); + String json = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + 
assertEquals(totalUpdates, JsonDeserializer.getLongValue(json, RCFPollingResponse.TOTAL_UPDATES_KEY)); + } + + public void testRequestToXContent() throws IOException, JsonPathNotFoundException { + RCFPollingRequest response = new RCFPollingRequest(detectorId); + String json = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + assertEquals(detectorId, JsonDeserializer.getTextValue(json, ADCommonName.ID_JSON_KEY)); + } + + public void testNullDetectorId() { + String nullDetectorId = null; + RCFPollingRequest emptyRequest = new RCFPollingRequest(nullDetectorId); + assertTrue(emptyRequest.validate() != null); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/RCFResultITTests.java b/src/test/java/org/opensearch/ad/transport/RCFResultITTests.java index 2354d4657..77a8f65ff 100644 --- a/src/test/java/org/opensearch/ad/transport/RCFResultITTests.java +++ b/src/test/java/org/opensearch/ad/transport/RCFResultITTests.java @@ -17,19 +17,19 @@ import org.opensearch.action.ActionFuture; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; public class RCFResultITTests extends OpenSearchIntegTestCase { @Override protected Collection> nodePlugins() { - return Collections.singletonList(AnomalyDetectorPlugin.class); + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } protected Collection> transportClientPlugins() { - return Collections.singletonList(AnomalyDetectorPlugin.class); + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } public void testEmptyFeature() throws ExecutionException, InterruptedException { diff --git a/src/test/java/org/opensearch/ad/transport/RCFResultITTests.java-e b/src/test/java/org/opensearch/ad/transport/RCFResultITTests.java-e new file mode 100644 index 000000000..77a8f65ff --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/RCFResultITTests.java-e @@ -0,0 +1,50 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.util.Collection; +import java.util.Collections; +import java.util.concurrent.ExecutionException; + +import org.opensearch.action.ActionFuture; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +public class RCFResultITTests extends OpenSearchIntegTestCase { + + @Override + protected Collection> nodePlugins() { + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); + } + + protected Collection> transportClientPlugins() { + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); + } + + public void testEmptyFeature() throws ExecutionException, InterruptedException { + RCFResultRequest request = new RCFResultRequest("123", "123-rcfmodel-1", new double[] {}); + + ActionFuture future = client().execute(RCFResultAction.INSTANCE, request); + + expectThrows(ActionRequestValidationException.class, () -> future.actionGet()); + } + + public void testIDIsNull() throws ExecutionException, InterruptedException { + RCFResultRequest request = new RCFResultRequest(null, "123-rcfmodel-1", new double[] { 0 }); + + ActionFuture future = client().execute(RCFResultAction.INSTANCE, request); + + expectThrows(ActionRequestValidationException.class, () -> future.actionGet()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/RCFResultTests.java b/src/test/java/org/opensearch/ad/transport/RCFResultTests.java index 7b46a0697..8f26af293 100644 --- a/src/test/java/org/opensearch/ad/transport/RCFResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/RCFResultTests.java @@ -51,8 +51,8 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.tasks.Task; diff --git a/src/test/java/org/opensearch/ad/transport/RCFResultTests.java-e b/src/test/java/org/opensearch/ad/transport/RCFResultTests.java-e new file mode 100644 index 000000000..801814611 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/RCFResultTests.java-e @@ -0,0 +1,366 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItem; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import org.hamcrest.Matchers; +import org.junit.Before; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.common.exception.JsonPathNotFoundException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.stats.ADStat; +import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.stats.suppliers.CounterSupplier; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportService; + +import test.org.opensearch.ad.util.JsonDeserializer; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; + +public class RCFResultTests extends OpenSearchTestCase { + Gson gson = new GsonBuilder().create(); + + private double[] attribution = new double[] { 1. 
}; + private HashRing hashRing; + private DiscoveryNode node; + private long totalUpdates = 32; + private double grade = 0.5; + private double[] pastValues = new double[] { 123, 456 }; + private double[][] expectedValuesList = new double[][] { new double[] { 789, 12 } }; + private double[] likelihood = new double[] { randomDouble() }; + private double threshold = 1.1d; + private ADStats adStats; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + hashRing = mock(HashRing.class); + node = mock(DiscoveryNode.class); + doReturn(Optional.of(node)).when(hashRing).getNodeByAddress(any()); + Map> statsMap = new HashMap>() { + { + put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + } + }; + + adStats = new ADStats(statsMap); + } + + @SuppressWarnings("unchecked") + public void testNormal() { + TransportService transportService = new TransportService( + Settings.EMPTY, + mock(Transport.class), + null, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + x -> null, + null, + Collections.emptySet() + ); + + ModelManager manager = mock(ModelManager.class); + ADCircuitBreakerService adCircuitBreakerService = mock(ADCircuitBreakerService.class); + RCFResultTransportAction action = new RCFResultTransportAction( + mock(ActionFilters.class), + transportService, + manager, + adCircuitBreakerService, + hashRing, + adStats + ); + + double rcfScore = 0.5; + int forestSize = 25; + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener + .onResponse( + new ThresholdingResult( + grade, + 0d, + rcfScore, + totalUpdates, + 0, + attribution, + pastValues, + expectedValuesList, + likelihood, + threshold, + forestSize + ) + ); + return null; + }).when(manager).getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); + + when(adCircuitBreakerService.isOpen()).thenReturn(false); + + final PlainActionFuture future = new PlainActionFuture<>(); + RCFResultRequest request = new RCFResultRequest("123", "123-rcf-1", new double[] { 0 }); + action.doExecute(mock(Task.class), request, future); + + RCFResultResponse response = future.actionGet(); + assertEquals(rcfScore, response.getRCFScore(), 0.001); + assertEquals(forestSize, response.getForestSize(), 0.001); + assertTrue(Arrays.equals(attribution, response.getAttribution())); + } + + @SuppressWarnings("unchecked") + public void testExecutionException() { + TransportService transportService = new TransportService( + Settings.EMPTY, + mock(Transport.class), + null, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + x -> null, + null, + Collections.emptySet() + ); + + ModelManager manager = mock(ModelManager.class); + ADCircuitBreakerService adCircuitBreakerService = mock(ADCircuitBreakerService.class); + RCFResultTransportAction action = new RCFResultTransportAction( + mock(ActionFilters.class), + transportService, + manager, + adCircuitBreakerService, + hashRing, + adStats + ); + doThrow(NullPointerException.class) + .when(manager) + .getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); + when(adCircuitBreakerService.isOpen()).thenReturn(false); + + final PlainActionFuture future = new PlainActionFuture<>(); + RCFResultRequest request = new RCFResultRequest("123", "123-rcf-1", new double[] { 0 }); + action.doExecute(mock(Task.class), request, future); + + 
expectThrows(NullPointerException.class, () -> future.actionGet()); + } + + public void testSerialzationResponse() throws IOException { + RCFResultResponse response = new RCFResultResponse( + 0.3, + 0, + 26, + attribution, + totalUpdates, + grade, + Version.CURRENT, + 0, + null, + null, + null, + 1.1 + ); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + RCFResultResponse readResponse = RCFResultAction.INSTANCE.getResponseReader().read(streamInput); + assertThat(response.getForestSize(), equalTo(readResponse.getForestSize())); + assertThat(response.getRCFScore(), equalTo(readResponse.getRCFScore())); + assertArrayEquals(response.getAttribution(), readResponse.getAttribution(), 1e-6); + } + + public void testJsonResponse() throws IOException, JsonPathNotFoundException { + RCFResultResponse response = new RCFResultResponse( + 0.3, + 0, + 26, + attribution, + totalUpdates, + grade, + Version.CURRENT, + 0, + null, + null, + null, + 1.1 + ); + XContentBuilder builder = jsonBuilder(); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String json = Strings.toString(builder); + assertEquals(JsonDeserializer.getDoubleValue(json, RCFResultResponse.RCF_SCORE_JSON_KEY), response.getRCFScore(), 0.001); + assertEquals(JsonDeserializer.getDoubleValue(json, RCFResultResponse.FOREST_SIZE_JSON_KEY), response.getForestSize(), 0.001); + assertTrue( + Arrays.equals(JsonDeserializer.getDoubleArrayValue(json, RCFResultResponse.ATTRIBUTION_JSON_KEY), response.getAttribution()) + ); + } + + public void testEmptyID() { + ActionRequestValidationException e = new RCFResultRequest(null, "123-rcf-1", new double[] { 0 }).validate(); + assertThat(e.validationErrors(), Matchers.hasItem(ADCommonMessages.AD_ID_MISSING_MSG)); + } + + public void testFeatureIsNull() { + ActionRequestValidationException e = new RCFResultRequest("123", "123-rcf-1", null).validate(); + assertThat(e.validationErrors(), hasItem(RCFResultRequest.INVALID_FEATURE_MSG)); + } + + public void testSerialzationRequest() throws IOException { + RCFResultRequest response = new RCFResultRequest("123", "123-rcf-1", new double[] { 0 }); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + RCFResultRequest readResponse = new RCFResultRequest(streamInput); + assertThat(response.getAdID(), equalTo(readResponse.getAdID())); + assertThat(response.getFeatures(), equalTo(readResponse.getFeatures())); + } + + public void testJsonRequest() throws IOException, JsonPathNotFoundException { + RCFResultRequest request = new RCFResultRequest("123", "123-rcf-1", new double[] { 0 }); + XContentBuilder builder = jsonBuilder(); + request.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String json = Strings.toString(builder); + assertEquals(JsonDeserializer.getTextValue(json, ADCommonName.ID_JSON_KEY), request.getAdID()); + assertArrayEquals(JsonDeserializer.getDoubleArrayValue(json, ADCommonName.FEATURE_JSON_KEY), request.getFeatures(), 0.001); + } + + @SuppressWarnings("unchecked") + public void testCircuitBreaker() { + TransportService transportService = new TransportService( + Settings.EMPTY, + mock(Transport.class), + null, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + x -> null, + null, + Collections.emptySet() + ); + + ModelManager manager = mock(ModelManager.class); + ADCircuitBreakerService breakerService = mock(ADCircuitBreakerService.class); + RCFResultTransportAction 
action = new RCFResultTransportAction( + mock(ActionFilters.class), + transportService, + manager, + breakerService, + hashRing, + adStats + ); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener + .onResponse( + new ThresholdingResult( + grade, + 0d, + 0.5, + totalUpdates, + 0, + attribution, + pastValues, + expectedValuesList, + likelihood, + threshold, + 30 + ) + ); + return null; + }).when(manager).getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); + when(breakerService.isOpen()).thenReturn(true); + + final PlainActionFuture future = new PlainActionFuture<>(); + RCFResultRequest request = new RCFResultRequest("123", "123-rcf-1", new double[] { 0 }); + action.doExecute(mock(Task.class), request, future); + + expectThrows(LimitExceededException.class, () -> future.actionGet()); + } + + @SuppressWarnings("unchecked") + public void testCorruptModel() { + TransportService transportService = new TransportService( + Settings.EMPTY, + mock(Transport.class), + null, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + x -> null, + null, + Collections.emptySet() + ); + + ModelManager manager = mock(ModelManager.class); + ADCircuitBreakerService adCircuitBreakerService = mock(ADCircuitBreakerService.class); + RCFResultTransportAction action = new RCFResultTransportAction( + mock(ActionFilters.class), + transportService, + manager, + adCircuitBreakerService, + hashRing, + adStats + ); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new IllegalArgumentException()); + return null; + }).when(manager).getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class)); + + when(adCircuitBreakerService.isOpen()).thenReturn(false); + + final PlainActionFuture future = new PlainActionFuture<>(); + String detectorId = "123"; + RCFResultRequest request = new RCFResultRequest(detectorId, "123-rcf-1", new double[] { 0 }); + action.doExecute(mock(Task.class), request, future); + + expectThrows(IllegalArgumentException.class, () -> future.actionGet()); + Object val = adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).getValue(); + assertEquals(1L, ((Long) val).longValue()); + verify(manager, times(1)).clear(eq(detectorId), any()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/SearchADTasksActionTests.java-e b/src/test/java/org/opensearch/ad/transport/SearchADTasksActionTests.java-e new file mode 100644 index 000000000..0d7ea4d9d --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/SearchADTasksActionTests.java-e @@ -0,0 +1,43 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.timeseries.TestHelpers.matchAllRequest; + +import java.io.IOException; + +import org.junit.Test; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.HistoricalAnalysisIntegTestCase; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.timeseries.TestHelpers; + +public class SearchADTasksActionTests extends HistoricalAnalysisIntegTestCase { + + @Test + public void testSearchADTasksAction() throws IOException { + createDetectionStateIndex(); + String adTaskId = createADTask(TestHelpers.randomAdTask()); + + SearchResponse searchResponse = client().execute(SearchADTasksAction.INSTANCE, matchAllRequest()).actionGet(10000); + assertEquals(1, searchResponse.getInternalResponse().hits().getTotalHits().value); + assertEquals(adTaskId, searchResponse.getInternalResponse().hits().getAt(0).getId()); + } + + @Test + public void testNoIndex() { + deleteIndexIfExists(ADCommonName.DETECTION_STATE_INDEX); + SearchResponse searchResponse = client().execute(SearchADTasksAction.INSTANCE, matchAllRequest()).actionGet(10000); + assertEquals(0, searchResponse.getInternalResponse().hits().getTotalHits().value); + } + +} diff --git a/src/test/java/org/opensearch/ad/transport/SearchADTasksTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/SearchADTasksTransportActionTests.java-e new file mode 100644 index 000000000..bc87faf13 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/SearchADTasksTransportActionTests.java-e @@ -0,0 +1,90 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; + +import org.junit.Before; +import org.junit.Ignore; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.HistoricalAnalysisIntegTestCase; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.ADTask; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.OpenSearchIntegTestCase; + +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2) +public class SearchADTasksTransportActionTests extends HistoricalAnalysisIntegTestCase { + + private Instant startTime; + private Instant endTime; + private String type = "error"; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + startTime = Instant.now().minus(10, ChronoUnit.DAYS); + endTime = Instant.now(); + ingestTestData(testIndex, startTime, detectionIntervalInMinutes, type, 2000); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return Settings + .builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 1) + .put(MAX_BATCH_TASK_PER_NODE.getKey(), 1) + .build(); + } + + public void testSearchWithoutTaskIndex() { + SearchRequest request = searchRequest(false); + expectThrows(IndexNotFoundException.class, () -> client().execute(SearchADTasksAction.INSTANCE, request).actionGet(10000)); + } + + public void testSearchWithNoTasks() throws IOException { + createDetectionStateIndex(); + SearchRequest request = searchRequest(false); + SearchResponse response = client().execute(SearchADTasksAction.INSTANCE, request).actionGet(10000); + assertEquals(0, response.getHits().getTotalHits().value); + } + + @Ignore + public void testSearchWithExistingTask() throws IOException { + startHistoricalAnalysis(startTime, endTime); + SearchRequest searchRequest = searchRequest(true); + SearchResponse response = client().execute(SearchADTasksAction.INSTANCE, searchRequest).actionGet(10000); + assertEquals(1, response.getHits().getTotalHits().value); + } + + private SearchRequest searchRequest(boolean isLatest) { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(ADTask.IS_LATEST_FIELD, isLatest)); + sourceBuilder.query(query); + SearchRequest request = new SearchRequest().source(sourceBuilder).indices(ADCommonName.DETECTION_STATE_INDEX); + return request; + } + +} diff --git a/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorActionTests.java-e b/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorActionTests.java-e new file mode 100644 index 000000000..96099af31 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorActionTests.java-e @@ -0,0 +1,75 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 
license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.model.AnomalyDetector.DETECTOR_TYPE_FIELD; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.HistoricalAnalysisIntegTestCase; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorType; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonName; + +import com.google.common.collect.ImmutableList; + +public class SearchAnomalyDetectorActionTests extends HistoricalAnalysisIntegTestCase { + + private String indexName = "test-data"; + private Instant startTime = Instant.now().minus(2, ChronoUnit.DAYS); + + public void testSearchDetectorAction() throws IOException { + ingestTestData(indexName, startTime, 1, "test", 3000); + String detectorType = AnomalyDetectorType.SINGLE_ENTITY.name(); + AnomalyDetector detector = TestHelpers + .randomAnomalyDetector( + ImmutableList.of(indexName), + ImmutableList.of(TestHelpers.randomFeature(true)), + null, + Instant.now(), + 1, + false, + null + ); + createDetectorIndex(); + String detectorId = createDetector(detector); + + BoolQueryBuilder query = new BoolQueryBuilder().filter(new TermQueryBuilder(DETECTOR_TYPE_FIELD, detectorType)); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); + SearchRequest request = new SearchRequest().source(searchSourceBuilder); + + SearchResponse searchResponse = client().execute(SearchAnomalyDetectorAction.INSTANCE, request).actionGet(10000); + assertEquals(1, searchResponse.getInternalResponse().hits().getTotalHits().value); + assertEquals(detectorId, searchResponse.getInternalResponse().hits().getAt(0).getId()); + } + + public void testNoIndex() { + deleteIndexIfExists(CommonName.CONFIG_INDEX); + + BoolQueryBuilder query = new BoolQueryBuilder().filter(new MatchAllQueryBuilder()); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); + SearchRequest request = new SearchRequest().source(searchSourceBuilder); + + SearchResponse searchResponse = client().execute(SearchAnomalyDetectorAction.INSTANCE, request).actionGet(10000); + assertEquals(0, searchResponse.getInternalResponse().hits().getTotalHits().value); + } + +} diff --git a/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java b/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java index 4e56ceb63..b67ec6aec 100644 --- a/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java @@ -34,10 +34,10 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import 
org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchIntegTestCase; diff --git a/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java-e b/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java-e new file mode 100644 index 000000000..03d8aa943 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java-e @@ -0,0 +1,219 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.timeseries.TestHelpers.createEmptySearchResponse; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.transport.TransportService; + +public class SearchAnomalyDetectorInfoActionTests extends OpenSearchIntegTestCase { + private SearchAnomalyDetectorInfoRequest request; + private ActionListener<SearchAnomalyDetectorInfoResponse> response; + private SearchAnomalyDetectorInfoTransportAction action; + private Task task; + private ClusterService clusterService; + private Client client; + private ThreadPool threadPool; + ThreadContext threadContext; + private PlainActionFuture<SearchAnomalyDetectorInfoResponse> future; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + action = new SearchAnomalyDetectorInfoTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + client(), + clusterService() + ); + task = mock(Task.class); + response = new ActionListener<SearchAnomalyDetectorInfoResponse>() { + @Override + public void onResponse(SearchAnomalyDetectorInfoResponse response) { + Assert.assertEquals(response.getCount(), 0); + Assert.assertEquals(response.isNameExists(), false); + } + + @Override + public void onFailure(Exception e) { + Assert.assertTrue(true); + } + }; + + future = mock(PlainActionFuture.class); + client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + threadPool = mock(ThreadPool.class); + 
when(client.threadPool()).thenReturn(threadPool); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + } + + @Test + public void testSearchCount() throws IOException { + // Anomaly Detectors index will not exist, onResponse will be called + SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest(null, "count"); + action.doExecute(task, request, response); + } + + @Test + public void testSearchMatch() throws IOException { + // Anomaly Detectors index will not exist, onResponse will be called + SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "match"); + action.doExecute(task, request, response); + } + + @Test + public void testSearchInfoAction() { + Assert.assertNotNull(SearchAnomalyDetectorInfoAction.INSTANCE.name()); + Assert.assertEquals(SearchAnomalyDetectorInfoAction.INSTANCE.name(), SearchAnomalyDetectorInfoAction.NAME); + } + + @Test + public void testSearchInfoRequest() throws IOException { + SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "match"); + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + StreamInput input = out.bytes().streamInput(); + SearchAnomalyDetectorInfoRequest newRequest = new SearchAnomalyDetectorInfoRequest(input); + Assert.assertEquals(request.getName(), newRequest.getName()); + Assert.assertEquals(request.getRawPath(), newRequest.getRawPath()); + Assert.assertNull(newRequest.validate()); + } + + @Test + public void testSearchInfoResponse() throws IOException { + SearchAnomalyDetectorInfoResponse response = new SearchAnomalyDetectorInfoResponse(1, true); + BytesStreamOutput out = new BytesStreamOutput(); + response.writeTo(out); + StreamInput input = out.bytes().streamInput(); + SearchAnomalyDetectorInfoResponse newResponse = new SearchAnomalyDetectorInfoResponse(input); + Assert.assertEquals(response.getCount(), newResponse.getCount()); + Assert.assertEquals(response.isNameExists(), newResponse.isNameExists()); + Assert.assertNotNull(response.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + } + + public void testSearchInfoResponse_CountSuccessWithEmptyResponse() throws IOException { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + SearchResponse searchResponse = createEmptySearchResponse(); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + action = new SearchAnomalyDetectorInfoTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + client, + clusterService + ); + SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "count"); + action.doExecute(task, request, future); + verify(future).onResponse(any(SearchAnomalyDetectorInfoResponse.class)); + } + + public void testSearchInfoResponse_MatchSuccessWithEmptyResponse() throws IOException { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + SearchResponse searchResponse = 
createEmptySearchResponse(); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + action = new SearchAnomalyDetectorInfoTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + client, + clusterService + ); + SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "match"); + action.doExecute(task, request, future); + verify(future).onResponse(any(SearchAnomalyDetectorInfoResponse.class)); + } + + public void testSearchInfoResponse_CountRuntimeException() throws IOException { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + listener.onFailure(new RuntimeException("searchResponse failed!")); + return null; + }).when(client).search(any(), any()); + action = new SearchAnomalyDetectorInfoTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + client, + clusterService + ); + SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "count"); + action.doExecute(task, request, future); + verify(future).onFailure(any(RuntimeException.class)); + } + + public void testSearchInfoResponse_MatchRuntimeException() throws IOException { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + listener.onFailure(new RuntimeException("searchResponse failed!")); + return null; + }).when(client).search(any(), any()); + action = new SearchAnomalyDetectorInfoTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + client, + clusterService + ); + SearchAnomalyDetectorInfoRequest request = new SearchAnomalyDetectorInfoRequest("testDetector", "match"); + action.doExecute(task, request, future); + verify(future).onFailure(any(RuntimeException.class)); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/SearchAnomalyResultActionTests.java-e b/src/test/java/org/opensearch/ad/transport/SearchAnomalyResultActionTests.java-e new file mode 100644 index 000000000..ac902a55e --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/SearchAnomalyResultActionTests.java-e @@ -0,0 +1,323 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN; +import static org.opensearch.timeseries.TestHelpers.createClusterState; +import static org.opensearch.timeseries.TestHelpers.createSearchResponse; +import static org.opensearch.timeseries.TestHelpers.matchAllRequest; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; + +import org.apache.lucene.util.BytesRef; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.MultiSearchResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.ad.HistoricalAnalysisIntegTestCase; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.transport.handler.ADSearchHandler; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.InternalOrder; +import org.opensearch.search.aggregations.bucket.terms.StringTerms; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableList; + +public class SearchAnomalyResultActionTests extends HistoricalAnalysisIntegTestCase { + private SearchAnomalyResultTransportAction action; + private TransportService transportService; + private ThreadPool threadPool; + private ThreadContext threadContext; + private Client client; + private ClusterService clusterService; + private ActionFilters actionFilters; + private ADSearchHandler searchHandler; + private IndexNameExpressionResolver indexNameExpressionResolver; + private PlainActionFuture future; + private ClusterState clusterState; + private SearchResponse searchResponse; + private MultiSearchResponse multiSearchResponse; + private StringTerms resultIndicesAgg; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES))) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + clusterState = 
createClusterState(); + when(clusterService.state()).thenReturn(clusterState); + + transportService = new TransportService( + Settings.EMPTY, + mock(Transport.class), + null, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + x -> null, + null, + Collections.emptySet() + ); + + client = mock(Client.class); + threadPool = mock(ThreadPool.class); + when(client.threadPool()).thenReturn(threadPool); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + actionFilters = mock(ActionFilters.class); + searchHandler = mock(ADSearchHandler.class); + indexNameExpressionResolver = mock(IndexNameExpressionResolver.class); + action = new SearchAnomalyResultTransportAction( + transportService, + actionFilters, + searchHandler, + clusterService, + indexNameExpressionResolver, + client + ); + } + + @Test + public void testSearchAnomalyResult_NoIndices() { + future = mock(PlainActionFuture.class); + SearchRequest request = new SearchRequest().indices(new String[] {}); + + action.doExecute(mock(Task.class), request, future); + verify(future).onFailure(any(IllegalArgumentException.class)); + } + + @Test + public void testSearchAnomalyResult_NullAggregationInSearchResponse() { + future = mock(PlainActionFuture.class); + SearchRequest request = new SearchRequest().indices(new String[] { "opensearch-ad-plugin-result-test" }); + when(indexNameExpressionResolver.concreteIndexNames(clusterState, IndicesOptions.lenientExpandOpen(), request.indices())) + .thenReturn(new String[] { "opensearch-ad-plugin-result-test" }); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + listener.onResponse(createSearchResponse(TestHelpers.randomAnomalyDetectResult(0.87))); + return null; + }).when(client).search(any(), any()); + + action.doExecute(mock(Task.class), request, future); + verify(client).search(any(), any()); + } + + @Test + public void testSearchAnomalyResult_EmptyBucketsInSearchResponse() { + searchResponse = mock(SearchResponse.class); + resultIndicesAgg = new StringTerms( + "result_index", + InternalOrder.key(false), + BucketOrder.count(false), + 1, + 0, + Collections.emptyMap(), + DocValueFormat.RAW, + 1, + false, + 0, + ImmutableList.of(), + 0 + ); + List list = new ArrayList<>(); + list.add(resultIndicesAgg); + Aggregations aggregations = new Aggregations(list); + + when(searchResponse.getAggregations()).thenReturn(aggregations); + + action + .processSingleSearchResponse( + searchResponse, + mock(SearchRequest.class), + mock(PlainActionFuture.class), + new HashSet<>(), + new ArrayList<>() + ); + verify(searchHandler).search(any(), any()); + } + + @Test + public void testSearchAnomalyResult_NullBucketsInSearchResponse() { + searchResponse = mock(SearchResponse.class); + resultIndicesAgg = new StringTerms( + "result_index", + InternalOrder.key(false), + BucketOrder.count(false), + 1, + 0, + Collections.emptyMap(), + DocValueFormat.RAW, + 1, + false, + 0, + null, + 0 + ); + List list = new ArrayList<>(); + list.add(resultIndicesAgg); + Aggregations aggregations = new Aggregations(list); + + when(searchResponse.getAggregations()).thenReturn(aggregations); + + action + .processSingleSearchResponse( + searchResponse, + mock(SearchRequest.class), + mock(PlainActionFuture.class), + new HashSet<>(), + new ArrayList<>() + ); + verify(searchHandler).search(any(), any()); + } + + @Test + public void 
testMultiSearch_NoOnlyQueryCustomResultIndex() { + action + .multiSearch( + Arrays.asList("test"), + mock(SearchRequest.class), + mock(PlainActionFuture.class), + false, + threadContext.stashContext() + ); + + verify(client).multiSearch(any(), any()); + } + + @Test + public void testSearchAnomalyResult_MultiSearch() { + future = mock(PlainActionFuture.class); + SearchRequest request = new SearchRequest().indices(new String[] { "opensearch-ad-plugin-result-test" }); + when(indexNameExpressionResolver.concreteIndexNames(clusterState, IndicesOptions.lenientExpandOpen(), request.indices())) + .thenReturn(new String[] { "opensearch-ad-plugin-result-test" }); + + searchResponse = mock(SearchResponse.class); + resultIndicesAgg = new StringTerms( + "result_index", + InternalOrder.key(false), + BucketOrder.count(false), + 1, + 0, + Collections.emptyMap(), + DocValueFormat.RAW, + 1, + false, + 0, + createBuckets(), + 0 + ); + List list = new ArrayList<>(); + list.add(resultIndicesAgg); + Aggregations aggregations = new Aggregations(list); + + when(searchResponse.getAggregations()).thenReturn(aggregations); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + multiSearchResponse = mock(MultiSearchResponse.class); + MultiSearchResponse.Item multiSearchResponseItem = mock(MultiSearchResponse.Item.class); + when(multiSearchResponse.getResponses()).thenReturn(new MultiSearchResponse.Item[] { multiSearchResponseItem }); + when(multiSearchResponseItem.getFailure()).thenReturn(null); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + listener.onResponse(multiSearchResponse); + return null; + }).when(client).multiSearch(any(), any()); + + action.doExecute(mock(Task.class), request, future); + verify(client).search(any(), any()); + verify(client).multiSearch(any(), any()); + verify(searchHandler).search(any(), any()); + } + + @Test + public void testSearchResultAction() throws IOException { + createADResultIndex(); + String adResultId = createADResult(TestHelpers.randomAnomalyDetectResult()); + + SearchResponse searchResponse = client() + .execute(SearchAnomalyResultAction.INSTANCE, matchAllRequest().indices(ALL_AD_RESULTS_INDEX_PATTERN)) + .actionGet(10000); + assertEquals(1, searchResponse.getInternalResponse().hits().getTotalHits().value); + + assertEquals(adResultId, searchResponse.getInternalResponse().hits().getAt(0).getId()); + } + + @Test + public void testNoIndex() { + deleteIndexIfExists(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); + SearchResponse searchResponse = client() + .execute(SearchAnomalyResultAction.INSTANCE, matchAllRequest().indices(ALL_AD_RESULTS_INDEX_PATTERN)) + .actionGet(10000); + assertEquals(0, searchResponse.getHits().getTotalHits().value); + } + + private List createBuckets() { + String entity1Name = "opensearch-ad-plugin-result-test"; + long entity1Count = 3; + StringTerms.Bucket entity1Bucket = new StringTerms.Bucket( + new BytesRef(entity1Name.getBytes(StandardCharsets.UTF_8), 0, entity1Name.getBytes(StandardCharsets.UTF_8).length), + entity1Count, + null, + false, + 0L, + DocValueFormat.RAW + ); + List stringBuckets = ImmutableList.of(entity1Bucket); + return stringBuckets; + } +} diff --git a/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultActionTests.java-e 
b/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultActionTests.java-e new file mode 100644 index 000000000..1e214f209 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultActionTests.java-e @@ -0,0 +1,268 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; + +import org.junit.Before; +import org.opensearch.ad.HistoricalAnalysisIntegTestCase; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.TestHelpers; + +import com.google.common.collect.ImmutableList; + +// Only invalid test cases are covered here. This is due to issues with the lang-painless module not +// being installed on test clusters spun up in OpenSearchIntegTestCase classes (which this class extends), +// which is needed for executing the API on ingested data. Valid test cases are covered at the REST layer. +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST) +public class SearchTopAnomalyResultActionTests extends HistoricalAnalysisIntegTestCase { + + private String testIndex; + private String detectorId; + private String taskId; + private Instant startTime; + private Instant endTime; + private ImmutableList categoryFields; + private String type = "error"; + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return Settings + .builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 1) + .put(MAX_BATCH_TASK_PER_NODE.getKey(), 1) + .build(); + } + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + testIndex = "test_data"; + taskId = "test-task-id"; + startTime = Instant.now().minus(10, ChronoUnit.DAYS); + endTime = Instant.now(); + categoryFields = ImmutableList.of("test-field-1", "test-field-2"); + ingestTestData(); + createSystemIndices(); + createAndIndexDetector(); + } + + private void ingestTestData() { + ingestTestData(testIndex, startTime, 1, "test", 1); + } + + private void createSystemIndices() throws IOException { + createDetectorIndex(); + createADResultIndex(); + } + + private void createAndIndexDetector() throws IOException { + AnomalyDetector detector = TestHelpers + .randomAnomalyDetector( + ImmutableList.of(testIndex), + ImmutableList.of(TestHelpers.randomFeature(true)), + null, + Instant.now(), + 1, + false, + categoryFields + ); + detectorId = createDetector(detector); + + } + + public void testInstanceAndNameValid() { + assertNotNull(SearchTopAnomalyResultAction.INSTANCE.name()); + assertEquals(SearchTopAnomalyResultAction.INSTANCE.name(), SearchTopAnomalyResultAction.NAME); + } + + public void testInvalidOrder() { + SearchTopAnomalyResultRequest searchRequest = new SearchTopAnomalyResultRequest( + detectorId, + taskId, + false, + 1, + Arrays.asList(categoryFields.get(0)), + "invalid-order", + startTime, + endTime + ); + expectThrows( + IllegalArgumentException.class, + () -> client().execute(SearchTopAnomalyResultAction.INSTANCE, searchRequest).actionGet(10_000) + ); + } + + public void testNegativeSize() { + 
SearchTopAnomalyResultRequest searchRequest = new SearchTopAnomalyResultRequest( + detectorId, + taskId, + false, + -1, + Arrays.asList(categoryFields.get(0)), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + startTime, + endTime + ); + expectThrows( + IllegalArgumentException.class, + () -> client().execute(SearchTopAnomalyResultAction.INSTANCE, searchRequest).actionGet(10_000) + ); + } + + public void testZeroSize() { + SearchTopAnomalyResultRequest searchRequest = new SearchTopAnomalyResultRequest( + detectorId, + taskId, + false, + 0, + Arrays.asList(categoryFields.get(0)), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + startTime, + endTime + ); + expectThrows( + IllegalArgumentException.class, + () -> client().execute(SearchTopAnomalyResultAction.INSTANCE, searchRequest).actionGet(10_000) + ); + } + + public void testTooLargeSize() { + SearchTopAnomalyResultRequest searchRequest = new SearchTopAnomalyResultRequest( + detectorId, + taskId, + false, + 9999999, + Arrays.asList(categoryFields.get(0)), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + startTime, + endTime + ); + expectThrows( + IllegalArgumentException.class, + () -> client().execute(SearchTopAnomalyResultAction.INSTANCE, searchRequest).actionGet(10_000) + ); + } + + public void testMissingStartTime() { + SearchTopAnomalyResultRequest searchRequest = new SearchTopAnomalyResultRequest( + detectorId, + taskId, + false, + 1, + Arrays.asList(categoryFields.get(0)), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + null, + endTime + ); + expectThrows( + IllegalArgumentException.class, + () -> client().execute(SearchTopAnomalyResultAction.INSTANCE, searchRequest).actionGet(10_000) + ); + } + + public void testMissingEndTime() { + SearchTopAnomalyResultRequest searchRequest = new SearchTopAnomalyResultRequest( + detectorId, + taskId, + false, + 1, + Arrays.asList(categoryFields.get(0)), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + startTime, + null + ); + expectThrows( + IllegalArgumentException.class, + () -> client().execute(SearchTopAnomalyResultAction.INSTANCE, searchRequest).actionGet(10_000) + ); + } + + public void testInvalidStartAndEndTimes() { + SearchTopAnomalyResultRequest searchRequest = new SearchTopAnomalyResultRequest( + detectorId, + taskId, + false, + 1, + Arrays.asList(categoryFields.get(0)), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + endTime, + startTime + ); + expectThrows( + IllegalArgumentException.class, + () -> client().execute(SearchTopAnomalyResultAction.INSTANCE, searchRequest).actionGet(10_000) + ); + + Instant curTimeInMillis = Instant.now(); + SearchTopAnomalyResultRequest searchRequest2 = new SearchTopAnomalyResultRequest( + detectorId, + taskId, + false, + 1, + Arrays.asList(categoryFields.get(0)), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + curTimeInMillis, + curTimeInMillis + ); + expectThrows( + IllegalArgumentException.class, + () -> client().execute(SearchTopAnomalyResultAction.INSTANCE, searchRequest2).actionGet(10_000) + ); + } + + public void testNoExistingHistoricalTask() throws IOException { + SearchTopAnomalyResultRequest searchRequest = new SearchTopAnomalyResultRequest( + detectorId, + taskId, + true, + 1, + Arrays.asList(categoryFields.get(0)), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + startTime, + endTime + ); + expectThrows(Exception.class, () -> 
client().execute(SearchTopAnomalyResultAction.INSTANCE, searchRequest).actionGet(10_000)); + } + + public void testSearchOnNonHCDetector() throws IOException { + AnomalyDetector nonHCDetector = TestHelpers + .randomAnomalyDetector( + ImmutableList.of(testIndex), + ImmutableList.of(TestHelpers.randomFeature(true)), + null, + Instant.now(), + 1, + false, + ImmutableList.of() + ); + String nonHCDetectorId = createDetector(nonHCDetector); + SearchTopAnomalyResultRequest searchRequest = new SearchTopAnomalyResultRequest( + nonHCDetectorId, + taskId, + false, + 1, + Arrays.asList(categoryFields.get(0)), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + startTime, + endTime + ); + expectThrows( + IllegalArgumentException.class, + () -> client().execute(SearchTopAnomalyResultAction.INSTANCE, searchRequest).actionGet(10_000) + ); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultRequestTests.java b/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultRequestTests.java index 0668d71ca..d227a0392 100644 --- a/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultRequestTests.java +++ b/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultRequestTests.java @@ -15,7 +15,7 @@ import org.junit.Assert; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.timeseries.TestHelpers; diff --git a/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultRequestTests.java-e b/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultRequestTests.java-e new file mode 100644 index 000000000..d227a0392 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultRequestTests.java-e @@ -0,0 +1,189 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.junit.Assert; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; + +public class SearchTopAnomalyResultRequestTests extends OpenSearchTestCase { + + public void testSerialization() throws IOException { + SearchTopAnomalyResultRequest originalRequest = new SearchTopAnomalyResultRequest( + "test-detector-id", + "test-task-id", + false, + 1, + Arrays.asList("test-field"), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + Instant.now().minus(10, ChronoUnit.DAYS), + Instant.now().minus(2, ChronoUnit.DAYS) + ); + + BytesStreamOutput output = new BytesStreamOutput(); + originalRequest.writeTo(output); + StreamInput input = output.bytes().streamInput(); + SearchTopAnomalyResultRequest parsedRequest = new SearchTopAnomalyResultRequest(input); + assertEquals(originalRequest.getId(), parsedRequest.getId()); + assertEquals(originalRequest.getTaskId(), parsedRequest.getTaskId()); + 
assertEquals(originalRequest.getHistorical(), parsedRequest.getHistorical()); + assertEquals(originalRequest.getSize(), parsedRequest.getSize()); + assertEquals(originalRequest.getCategoryFields(), parsedRequest.getCategoryFields()); + assertEquals(originalRequest.getOrder(), parsedRequest.getOrder()); + assertEquals(originalRequest.getStartTime(), parsedRequest.getStartTime()); + assertEquals(originalRequest.getEndTime(), parsedRequest.getEndTime()); + } + + public void testParse() throws IOException { + String detectorId = "test-detector-id"; + boolean historical = false; + String taskId = "test-task-id"; + int size = 5; + List categoryFields = Arrays.asList("field-1", "field-2"); + String order = "severity"; + Instant startTime = Instant.ofEpochMilli(1234); + Instant endTime = Instant.ofEpochMilli(5678); + + XContentBuilder xContentBuilder = TestHelpers + .builder() + .startObject() + .field("task_id", taskId) + .field("size", size) + .field("category_field", categoryFields) + .field("order", order) + .field("start_time_ms", startTime.toEpochMilli()) + .field("end_time_ms", endTime.toEpochMilli()) + .endObject(); + + String requestAsXContentString = TestHelpers.xContentBuilderToString(xContentBuilder); + SearchTopAnomalyResultRequest parsedRequest = SearchTopAnomalyResultRequest + .parse(TestHelpers.parser(requestAsXContentString), "test-detector-id", false); + assertEquals(taskId, parsedRequest.getTaskId()); + assertEquals((Integer) size, parsedRequest.getSize()); + assertEquals(categoryFields, parsedRequest.getCategoryFields()); + assertEquals(order, parsedRequest.getOrder()); + assertEquals(startTime.toEpochMilli(), parsedRequest.getStartTime().toEpochMilli()); + assertEquals(endTime.toEpochMilli(), parsedRequest.getEndTime().toEpochMilli()); + assertEquals(detectorId, parsedRequest.getId()); + assertEquals(historical, parsedRequest.getHistorical()); + } + + public void testNullTaskIdIsValid() { + SearchTopAnomalyResultRequest request = new SearchTopAnomalyResultRequest( + "test-detector-id", + null, + false, + 1, + Arrays.asList("test-field"), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + Instant.now().minus(10, ChronoUnit.DAYS), + Instant.now().minus(2, ChronoUnit.DAYS) + ); + ActionRequestValidationException exception = request.validate(); + Assert.assertNull(exception); + } + + public void testNullSizeIsValid() { + SearchTopAnomalyResultRequest request = new SearchTopAnomalyResultRequest( + "test-detector-id", + "", + false, + null, + Arrays.asList("test-field"), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + Instant.now().minus(10, ChronoUnit.DAYS), + Instant.now().minus(2, ChronoUnit.DAYS) + ); + ActionRequestValidationException exception = request.validate(); + Assert.assertNull(exception); + } + + public void testNullCategoryFieldIsValid() { + SearchTopAnomalyResultRequest request = new SearchTopAnomalyResultRequest( + "test-detector-id", + "", + false, + 1, + null, + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + Instant.now().minus(10, ChronoUnit.DAYS), + Instant.now().minus(2, ChronoUnit.DAYS) + ); + ActionRequestValidationException exception = request.validate(); + Assert.assertNull(exception); + } + + public void testEmptyCategoryFieldIsValid() { + SearchTopAnomalyResultRequest request = new SearchTopAnomalyResultRequest( + "test-detector-id", + "", + false, + 1, + new ArrayList<>(), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + Instant.now().minus(10, ChronoUnit.DAYS), + 
Instant.now().minus(2, ChronoUnit.DAYS) + ); + ActionRequestValidationException exception = request.validate(); + Assert.assertNull(exception); + } + + public void testEmptyStartTimeIsInvalid() { + SearchTopAnomalyResultRequest request = new SearchTopAnomalyResultRequest( + "test-detector-id", + "", + false, + 1, + new ArrayList<>(), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + null, + Instant.now().minus(2, ChronoUnit.DAYS) + ); + ActionRequestValidationException exception = request.validate(); + Assert.assertNotNull(exception); + } + + public void testEmptyEndTimeIsInvalid() { + SearchTopAnomalyResultRequest request = new SearchTopAnomalyResultRequest( + "test-detector-id", + "", + false, + 1, + new ArrayList<>(), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + Instant.now().minus(10, ChronoUnit.DAYS), + null + ); + ActionRequestValidationException exception = request.validate(); + Assert.assertNotNull(exception); + } + + public void testEndTimeBeforeStartTimeIsInvalid() { + SearchTopAnomalyResultRequest request = new SearchTopAnomalyResultRequest( + "test-detector-id", + "", + false, + 1, + new ArrayList<>(), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + Instant.now().minus(2, ChronoUnit.DAYS), + Instant.now().minus(10, ChronoUnit.DAYS) + ); + ActionRequestValidationException exception = request.validate(); + Assert.assertNotNull(exception); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultResponseTests.java b/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultResponseTests.java index 4f9181081..a0a08087e 100644 --- a/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultResponseTests.java +++ b/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultResponseTests.java @@ -10,7 +10,7 @@ import java.util.Arrays; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.timeseries.TestHelpers; diff --git a/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultResponseTests.java-e b/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultResponseTests.java-e new file mode 100644 index 000000000..a0a08087e --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultResponseTests.java-e @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; + +public class SearchTopAnomalyResultResponseTests extends OpenSearchTestCase { + + public void testSerialization() throws IOException { + SearchTopAnomalyResultResponse originalResponse = new SearchTopAnomalyResultResponse( + Arrays.asList(TestHelpers.randomAnomalyResultBucket()) + ); + + BytesStreamOutput output = new BytesStreamOutput(); + originalResponse.writeTo(output); + StreamInput input = output.bytes().streamInput(); + SearchTopAnomalyResultResponse parsedResponse = new SearchTopAnomalyResultResponse(input); + assertEquals(originalResponse.getAnomalyResultBuckets(), 
parsedResponse.getAnomalyResultBuckets()); + } + + public void testEmptyResults() { + SearchTopAnomalyResultResponse response = new SearchTopAnomalyResultResponse(new ArrayList<>()); + } + + public void testPopulatedResults() { + SearchTopAnomalyResultResponse response = new SearchTopAnomalyResultResponse( + Arrays.asList(TestHelpers.randomAnomalyResultBucket()) + ); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportActionTests.java-e new file mode 100644 index 000000000..969dc4523 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportActionTests.java-e @@ -0,0 +1,361 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.stubbing.Answer; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.ADIntegTestCase; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyResultBucket; +import org.opensearch.ad.transport.handler.ADSearchHandler; +import org.opensearch.client.Client; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; +import org.opensearch.search.aggregations.metrics.InternalMax; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableList; + +public class SearchTopAnomalyResultTransportActionTests extends ADIntegTestCase { + private SearchTopAnomalyResultTransportAction action; + + // Helper method to generate the Aggregations obj using the list of result buckets + private Aggregations generateAggregationsFromBuckets(List<AnomalyResultBucket> buckets, Map<String, Object> mockAfterKeyValue) { + List<CompositeAggregation.Bucket> bucketList = new ArrayList<>(); + + for (AnomalyResultBucket bucket : buckets) { + InternalMax maxGradeAgg = mock(InternalMax.class); + when(maxGradeAgg.getName()).thenReturn(AnomalyResultBucket.MAX_ANOMALY_GRADE_FIELD); + when(maxGradeAgg.getValue()).thenReturn(bucket.getMaxAnomalyGrade()); + CompositeAggregation.Bucket aggBucket = mock(CompositeAggregation.Bucket.class); + when(aggBucket.getKey()).thenReturn(bucket.getKey()); + when(aggBucket.getDocCount()).thenReturn((long) bucket.getDocCount()); + when(aggBucket.getAggregations()).thenReturn(new Aggregations(new ArrayList<Aggregation>() { + { + add(maxGradeAgg); + } + })); + bucketList.add(aggBucket); + } + + CompositeAggregation composite = mock(CompositeAggregation.class); + when(composite.getName()).thenReturn(SearchTopAnomalyResultTransportAction.MULTI_BUCKETS_FIELD); + 
when(composite.getBuckets()).thenAnswer((Answer<List<CompositeAggregation.Bucket>>) invocation -> bucketList); + when(composite.afterKey()).thenReturn(mockAfterKeyValue); + + List<Aggregation> aggList = Collections.singletonList(composite); + return new Aggregations(aggList); + } + + // Helper method to generate a SearchResponse obj using the given aggs + private SearchResponse generateMockSearchResponse(Aggregations aggs) { + SearchResponseSections sections = new SearchResponseSections(SearchHits.empty(), aggs, null, false, null, null, 1); + return new SearchResponse(sections, null, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY, SearchResponse.Clusters.EMPTY); + } + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + action = new SearchTopAnomalyResultTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + mock(ADSearchHandler.class), + mock(Client.class) + ); + } + + public void testSearchOnNonExistingResultIndex() throws IOException { + deleteIndexIfExists(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); + String testIndexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); + ImmutableList<String> categoryFields = ImmutableList.of("test-field-1", "test-field-2"); + String detectorId = createDetector( + TestHelpers + .randomAnomalyDetector( + ImmutableList.of(testIndexName), + ImmutableList.of(TestHelpers.randomFeature(true)), + null, + Instant.now(), + 1, + false, + categoryFields + ) + ); + SearchTopAnomalyResultRequest searchRequest = new SearchTopAnomalyResultRequest( + detectorId, + null, + false, + 1, + Arrays.asList(categoryFields.get(0)), + SearchTopAnomalyResultTransportAction.OrderType.SEVERITY.getName(), + Instant.now().minus(10, ChronoUnit.DAYS), + Instant.now() + ); + SearchTopAnomalyResultResponse searchResponse = client() + .execute(SearchTopAnomalyResultAction.INSTANCE, searchRequest) + .actionGet(10_000); + assertEquals(searchResponse.getAnomalyResultBuckets().size(), 0); + } + + @SuppressWarnings("unchecked") + public void testListenerWithNullResult() { + ActionListener<SearchTopAnomalyResultResponse> mockListener = mock(ActionListener.class); + SearchTopAnomalyResultTransportAction.TopAnomalyResultListener listener = action.new TopAnomalyResultListener( + mockListener, new SearchSourceBuilder(), 1000, 10, SearchTopAnomalyResultTransportAction.OrderType.SEVERITY, + "custom-result-index-name" + ); + ArgumentCaptor<Exception> failureCaptor = ArgumentCaptor.forClass(Exception.class); + + listener.onResponse(null); + + verify(mockListener, times(1)).onFailure(failureCaptor.capture()); + assertTrue(failureCaptor.getValue() != null); + } + + @SuppressWarnings("unchecked") + public void testListenerWithNullAggregation() { + ActionListener<SearchTopAnomalyResultResponse> mockListener = mock(ActionListener.class); + SearchTopAnomalyResultTransportAction.TopAnomalyResultListener listener = action.new TopAnomalyResultListener( + mockListener, new SearchSourceBuilder(), 1000, 10, SearchTopAnomalyResultTransportAction.OrderType.SEVERITY, + "custom-result-index-name" + ); + + SearchResponse response = generateMockSearchResponse(null); + ArgumentCaptor<SearchTopAnomalyResultResponse> responseCaptor = ArgumentCaptor.forClass(SearchTopAnomalyResultResponse.class); + + listener.onResponse(response); + + verify(mockListener, times(1)).onResponse(responseCaptor.capture()); + SearchTopAnomalyResultResponse capturedResponse = responseCaptor.getValue(); + assertTrue(capturedResponse != null); + assertTrue(capturedResponse.getAnomalyResultBuckets() != null); + assertEquals(0, capturedResponse.getAnomalyResultBuckets().size()); + } + + @SuppressWarnings("unchecked") + public void 
testListenerWithInvalidAggregation() { + ActionListener mockListener = mock(ActionListener.class); + SearchTopAnomalyResultTransportAction.TopAnomalyResultListener listener = action.new TopAnomalyResultListener( + mockListener, new SearchSourceBuilder(), 1000, 10, SearchTopAnomalyResultTransportAction.OrderType.SEVERITY, + "custom-result-index-name" + ); + + // an empty list won't have an entry for 'MULTI_BUCKETS_FIELD' as needed to parse out + // the expected result buckets, and thus should fail + Aggregations aggs = new Aggregations(new ArrayList<>()); + SearchResponse response = generateMockSearchResponse(aggs); + ArgumentCaptor failureCaptor = ArgumentCaptor.forClass(Exception.class); + + listener.onResponse(response); + + verify(mockListener, times(1)).onFailure(failureCaptor.capture()); + assertTrue(failureCaptor.getValue() != null); + } + + @SuppressWarnings("unchecked") + public void testListenerWithValidEmptyAggregation() { + ActionListener mockListener = mock(ActionListener.class); + SearchTopAnomalyResultTransportAction.TopAnomalyResultListener listener = action.new TopAnomalyResultListener( + mockListener, new SearchSourceBuilder(), 1000, 10, SearchTopAnomalyResultTransportAction.OrderType.SEVERITY, + "custom-result-index-name" + ); + + CompositeAggregation composite = mock(CompositeAggregation.class); + when(composite.getName()).thenReturn(SearchTopAnomalyResultTransportAction.MULTI_BUCKETS_FIELD); + when(composite.getBuckets()).thenReturn(new ArrayList<>()); + when(composite.afterKey()).thenReturn(null); + List aggList = Collections.singletonList(composite); + Aggregations aggs = new Aggregations(aggList); + + SearchResponse response = generateMockSearchResponse(aggs); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(SearchTopAnomalyResultResponse.class); + + listener.onResponse(response); + + verify(mockListener, times(1)).onResponse(responseCaptor.capture()); + SearchTopAnomalyResultResponse capturedResponse = responseCaptor.getValue(); + assertTrue(capturedResponse != null); + assertTrue(capturedResponse.getAnomalyResultBuckets() != null); + assertEquals(0, capturedResponse.getAnomalyResultBuckets().size()); + } + + @SuppressWarnings("unchecked") + public void testListenerTimesOutWithNoResults() { + ActionListener mockListener = mock(ActionListener.class); + SearchTopAnomalyResultTransportAction.TopAnomalyResultListener listener = action.new TopAnomalyResultListener( + mockListener, new SearchSourceBuilder(), 1000, // this is guaranteed to be an expired timestamp + 10, SearchTopAnomalyResultTransportAction.OrderType.OCCURRENCE, "custom-result-index-name" + ); + + Aggregations aggs = generateAggregationsFromBuckets(new ArrayList<>(), new HashMap() { + { + put("category-field-name-1", "value-2"); + } + }); + SearchResponse response = generateMockSearchResponse(aggs); + ArgumentCaptor failureCaptor = ArgumentCaptor.forClass(Exception.class); + + listener.onResponse(response); + + verify(mockListener, times(1)).onFailure(failureCaptor.capture()); + assertTrue(failureCaptor.getValue() != null); + } + + @SuppressWarnings("unchecked") + public void testListenerTimesOutWithPartialResults() { + ActionListener mockListener = mock(ActionListener.class); + SearchTopAnomalyResultTransportAction.TopAnomalyResultListener listener = action.new TopAnomalyResultListener( + mockListener, new SearchSourceBuilder(), 1000, // this is guaranteed to be an expired timestamp + 10, SearchTopAnomalyResultTransportAction.OrderType.OCCURRENCE, "custom-result-index-name" + ); + + 
AnomalyResultBucket expectedResponseBucket1 = new AnomalyResultBucket(new HashMap() { + { + put("category-field-name-1", "value-1"); + } + }, 5, 0.2); + + Aggregations aggs = generateAggregationsFromBuckets(new ArrayList() { + { + add(expectedResponseBucket1); + } + }, new HashMap() { + { + put("category-field-name-1", "value-2"); + } + }); + + SearchResponse response = generateMockSearchResponse(aggs); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(SearchTopAnomalyResultResponse.class); + + listener.onResponse(response); + + verify(mockListener, times(1)).onResponse(responseCaptor.capture()); + SearchTopAnomalyResultResponse capturedResponse = responseCaptor.getValue(); + assertTrue(capturedResponse != null); + assertTrue(capturedResponse.getAnomalyResultBuckets() != null); + assertEquals(1, capturedResponse.getAnomalyResultBuckets().size()); + assertEquals(expectedResponseBucket1, capturedResponse.getAnomalyResultBuckets().get(0)); + } + + @SuppressWarnings("unchecked") + public void testListenerSortingBySeverity() { + ActionListener mockListener = mock(ActionListener.class); + SearchTopAnomalyResultTransportAction.TopAnomalyResultListener listener = action.new TopAnomalyResultListener( + mockListener, new SearchSourceBuilder(), 1000, 10, SearchTopAnomalyResultTransportAction.OrderType.SEVERITY, + "custom-result-index-name" + ); + + AnomalyResultBucket expectedResponseBucket1 = new AnomalyResultBucket(new HashMap() { + { + put("category-field-name-1", "value-1"); + } + }, 5, 0.2); + AnomalyResultBucket expectedResponseBucket2 = new AnomalyResultBucket(new HashMap() { + { + put("category-field-name-1", "value-2"); + } + }, 5, 0.3); + AnomalyResultBucket expectedResponseBucket3 = new AnomalyResultBucket(new HashMap() { + { + put("category-field-name-1", "value-3"); + } + }, 5, 0.1); + + Aggregations aggs = generateAggregationsFromBuckets(new ArrayList() { + { + add(expectedResponseBucket1); + add(expectedResponseBucket2); + add(expectedResponseBucket3); + } + }, null); + + SearchResponse response = generateMockSearchResponse(aggs); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(SearchTopAnomalyResultResponse.class); + + listener.onResponse(response); + + verify(mockListener, times(1)).onResponse(responseCaptor.capture()); + SearchTopAnomalyResultResponse capturedResponse = responseCaptor.getValue(); + assertTrue(capturedResponse != null); + assertTrue(capturedResponse.getAnomalyResultBuckets() != null); + assertEquals(3, capturedResponse.getAnomalyResultBuckets().size()); + assertEquals(expectedResponseBucket2, capturedResponse.getAnomalyResultBuckets().get(0)); + assertEquals(expectedResponseBucket1, capturedResponse.getAnomalyResultBuckets().get(1)); + assertEquals(expectedResponseBucket3, capturedResponse.getAnomalyResultBuckets().get(2)); + } + + @SuppressWarnings("unchecked") + public void testListenerSortingByOccurrence() { + ActionListener mockListener = mock(ActionListener.class); + SearchTopAnomalyResultTransportAction.TopAnomalyResultListener listener = action.new TopAnomalyResultListener( + mockListener, new SearchSourceBuilder(), 1000, 10, SearchTopAnomalyResultTransportAction.OrderType.OCCURRENCE, + "custom-result-index-name" + ); + + AnomalyResultBucket expectedResponseBucket1 = new AnomalyResultBucket(new HashMap() { + { + put("category-field-name-1", "value-1"); + } + }, 2, 0.5); + AnomalyResultBucket expectedResponseBucket2 = new AnomalyResultBucket(new HashMap() { + { + put("category-field-name-1", "value-2"); + } + }, 3, 0.5); + AnomalyResultBucket 
expectedResponseBucket3 = new AnomalyResultBucket(new HashMap() { + { + put("category-field-name-1", "value-3"); + } + }, 1, 0.5); + + Aggregations aggs = generateAggregationsFromBuckets(new ArrayList() { + { + add(expectedResponseBucket1); + add(expectedResponseBucket2); + add(expectedResponseBucket3); + } + }, null); + + SearchResponse response = generateMockSearchResponse(aggs); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(SearchTopAnomalyResultResponse.class); + + listener.onResponse(response); + + verify(mockListener, times(1)).onResponse(responseCaptor.capture()); + SearchTopAnomalyResultResponse capturedResponse = responseCaptor.getValue(); + assertTrue(capturedResponse != null); + assertTrue(capturedResponse.getAnomalyResultBuckets() != null); + assertEquals(3, capturedResponse.getAnomalyResultBuckets().size()); + assertEquals(expectedResponseBucket2, capturedResponse.getAnomalyResultBuckets().get(0)); + assertEquals(expectedResponseBucket1, capturedResponse.getAnomalyResultBuckets().get(1)); + assertEquals(expectedResponseBucket3, capturedResponse.getAnomalyResultBuckets().get(2)); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java index 53e498164..796d492e1 100644 --- a/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java @@ -24,8 +24,8 @@ import org.opensearch.ad.stats.ADStatsResponse; import org.opensearch.cluster.ClusterName; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; diff --git a/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java-e b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java-e new file mode 100644 index 000000000..74f7a6a12 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java-e @@ -0,0 +1,71 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.FailedNodeException; +import org.opensearch.ad.stats.ADStatsResponse; +import org.opensearch.cluster.ClusterName; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.test.OpenSearchTestCase; + +public class StatsAnomalyDetectorActionTests extends OpenSearchTestCase { + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + } + + @Test + public void testStatsAction() { + Assert.assertNotNull(StatsAnomalyDetectorAction.INSTANCE.name()); + Assert.assertEquals(StatsAnomalyDetectorAction.INSTANCE.name(), StatsAnomalyDetectorAction.NAME); + } + + @Test + public void testStatsResponse() throws IOException { + ADStatsResponse adStatsResponse = new ADStatsResponse(); + Map testClusterStats = new HashMap<>(); + testClusterStats.put("test_response", 1); + adStatsResponse.setClusterStats(testClusterStats); + List responses = Collections.emptyList(); + List failures = Collections.emptyList(); + ADStatsNodesResponse adStatsNodesResponse = new ADStatsNodesResponse(ClusterName.DEFAULT, responses, failures); + adStatsResponse.setADStatsNodesResponse(adStatsNodesResponse); + + StatsAnomalyDetectorResponse response = new StatsAnomalyDetectorResponse(adStatsResponse); + BytesStreamOutput out = new BytesStreamOutput(); + response.writeTo(out); + StreamInput input = out.bytes().streamInput(); + StatsAnomalyDetectorResponse newResponse = new StatsAnomalyDetectorResponse(input); + assertNotNull(newResponse); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder = newResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + XContentParser parser = createParser(builder); + assertEquals(1, parser.map().get("test_response")); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportActionTests.java-e new file mode 100644 index 000000000..7c877c086 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportActionTests.java-e @@ -0,0 +1,102 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.time.Instant; +import java.util.Map; + +import org.junit.Before; +import org.opensearch.ad.ADIntegTestCase; +import org.opensearch.ad.stats.InternalStatNames; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.stats.StatNames; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class StatsAnomalyDetectorTransportActionTests extends ADIntegTestCase { + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + createDetectors( + ImmutableList + .of( + TestHelpers.randomAnomalyDetector(ImmutableMap.of(), Instant.now()), + TestHelpers + .randomAnomalyDetector( + ImmutableList.of(TestHelpers.randomFeature()), + ImmutableMap.of(), + Instant.now(), + true, + ImmutableList.of(randomAlphaOfLength(5)) + ) + ), + true + ); + } + + public void testStatsAnomalyDetectorWithNodeLevelStats() { + ADStatsRequest adStatsRequest = new ADStatsRequest(clusterService().localNode()); + adStatsRequest.addStat(InternalStatNames.JVM_HEAP_USAGE.getName()); + StatsAnomalyDetectorResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); + assertEquals(1, response.getAdStatsResponse().getADStatsNodesResponse().getNodes().size()); + assertTrue( + response + .getAdStatsResponse() + .getADStatsNodesResponse() + .getNodes() + .get(0) + .getStatsMap() + .containsKey(InternalStatNames.JVM_HEAP_USAGE.getName()) + ); + } + + public void testStatsAnomalyDetectorWithClusterLevelStats() { + ADStatsRequest adStatsRequest = new ADStatsRequest(clusterService().localNode()); + adStatsRequest.addStat(StatNames.DETECTOR_COUNT.getName()); + adStatsRequest.addStat(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()); + StatsAnomalyDetectorResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); + assertEquals(1, response.getAdStatsResponse().getADStatsNodesResponse().getNodes().size()); + Map statsMap = response.getAdStatsResponse().getADStatsNodesResponse().getNodes().get(0).getStatsMap(); + Map clusterStats = response.getAdStatsResponse().getClusterStats(); + assertEquals(0, statsMap.size()); + assertEquals(2L, clusterStats.get(StatNames.DETECTOR_COUNT.getName())); + assertEquals(1L, clusterStats.get(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName())); + } + + public void testStatsAnomalyDetectorWithDetectorCount() { + ADStatsRequest adStatsRequest = new ADStatsRequest(clusterService().localNode()); + adStatsRequest.addStat(StatNames.DETECTOR_COUNT.getName()); + StatsAnomalyDetectorResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); + assertEquals(1, response.getAdStatsResponse().getADStatsNodesResponse().getNodes().size()); + Map statsMap = response.getAdStatsResponse().getADStatsNodesResponse().getNodes().get(0).getStatsMap(); + Map clusterStats = response.getAdStatsResponse().getClusterStats(); + assertEquals(0, statsMap.size()); + assertEquals(2L, clusterStats.get(StatNames.DETECTOR_COUNT.getName())); + assertFalse(clusterStats.containsKey(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName())); + } + + public void testStatsAnomalyDetectorWithSingleEntityDetectorCount() { + ADStatsRequest adStatsRequest = new ADStatsRequest(clusterService().localNode()); + adStatsRequest.addStat(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()); + StatsAnomalyDetectorResponse response = 
client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000); + assertEquals(1, response.getAdStatsResponse().getADStatsNodesResponse().getNodes().size()); + Map statsMap = response.getAdStatsResponse().getADStatsNodesResponse().getNodes().get(0).getStatsMap(); + Map clusterStats = response.getAdStatsResponse().getClusterStats(); + assertEquals(0, statsMap.size()); + assertEquals(1L, clusterStats.get(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName())); + assertFalse(clusterStats.containsKey(StatNames.DETECTOR_COUNT.getName())); + } + +} diff --git a/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java index c90449538..d6ed84d2d 100644 --- a/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java @@ -21,9 +21,9 @@ import org.opensearch.action.ActionResponse; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.test.OpenSearchIntegTestCase; diff --git a/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java-e b/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java-e new file mode 100644 index 000000000..ad0d61e3d --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java-e @@ -0,0 +1,100 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.ActionResponse; +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.test.OpenSearchIntegTestCase; + +public class StopDetectorActionTests extends OpenSearchIntegTestCase { + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + } + + @Test + public void testStopDetectorAction() { + Assert.assertNotNull(StopDetectorAction.INSTANCE.name()); + Assert.assertEquals(StopDetectorAction.INSTANCE.name(), StopDetectorAction.NAME); + } + + @Test + public void fromActionRequest_Success() { + StopDetectorRequest stopDetectorRequest = new StopDetectorRequest("adID"); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + stopDetectorRequest.writeTo(out); + } + }; + StopDetectorRequest result = StopDetectorRequest.fromActionRequest(actionRequest); + assertNotSame(result, stopDetectorRequest); + assertEquals(result.getAdID(), stopDetectorRequest.getAdID()); + } + + @Test + public void writeTo_Success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + StopDetectorResponse response = new StopDetectorResponse(true); + response.writeTo(bytesStreamOutput); + StopDetectorResponse parsedResponse = new StopDetectorResponse(bytesStreamOutput.bytes().streamInput()); + assertNotEquals(response, parsedResponse); + assertEquals(response.success(), parsedResponse.success()); + } + + @Test + public void fromActionResponse_Success() throws IOException { + StopDetectorResponse stopDetectorResponse = new StopDetectorResponse(true); + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + stopDetectorResponse.writeTo(streamOutput); + } + }; + StopDetectorResponse result = stopDetectorResponse.fromActionResponse(actionResponse); + assertNotSame(result, stopDetectorResponse); + assertEquals(result.success(), stopDetectorResponse.success()); + + StopDetectorResponse parsedStopDetectorResponse = stopDetectorResponse.fromActionResponse(stopDetectorResponse); + assertEquals(parsedStopDetectorResponse, stopDetectorResponse); + } + + @Test + public void toXContentTest() throws IOException { + StopDetectorResponse stopDetectorResponse = new StopDetectorResponse(true); + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + stopDetectorResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = Strings.toString(builder); + assertEquals("{\"success\":true}", jsonStr); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ThresholdResultITTests.java b/src/test/java/org/opensearch/ad/transport/ThresholdResultITTests.java index ae84f54f5..779ef8c8b 100644 --- a/src/test/java/org/opensearch/ad/transport/ThresholdResultITTests.java +++ 
b/src/test/java/org/opensearch/ad/transport/ThresholdResultITTests.java @@ -17,19 +17,19 @@ import org.opensearch.action.ActionFuture; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; public class ThresholdResultITTests extends OpenSearchIntegTestCase { @Override protected Collection> nodePlugins() { - return Collections.singletonList(AnomalyDetectorPlugin.class); + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } protected Collection> transportClientPlugins() { - return Collections.singletonList(AnomalyDetectorPlugin.class); + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } public void testEmptyID() throws ExecutionException, InterruptedException { diff --git a/src/test/java/org/opensearch/ad/transport/ThresholdResultITTests.java-e b/src/test/java/org/opensearch/ad/transport/ThresholdResultITTests.java-e new file mode 100644 index 000000000..779ef8c8b --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ThresholdResultITTests.java-e @@ -0,0 +1,50 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.util.Collection; +import java.util.Collections; +import java.util.concurrent.ExecutionException; + +import org.opensearch.action.ActionFuture; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +public class ThresholdResultITTests extends OpenSearchIntegTestCase { + + @Override + protected Collection> nodePlugins() { + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); + } + + protected Collection> transportClientPlugins() { + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); + } + + public void testEmptyID() throws ExecutionException, InterruptedException { + ThresholdResultRequest request = new ThresholdResultRequest("", "123-threshold", 2.5d); + + ActionFuture future = client().execute(ThresholdResultAction.INSTANCE, request); + + expectThrows(ActionRequestValidationException.class, () -> future.actionGet()); + } + + public void testIDIsNull() throws ExecutionException, InterruptedException { + ThresholdResultRequest request = new ThresholdResultRequest(null, "123-threshold", 2.5d); + + ActionFuture future = client().execute(ThresholdResultAction.INSTANCE, request); + + expectThrows(ActionRequestValidationException.class, () -> future.actionGet()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java b/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java index fdea0d340..9f2869c8c 100644 --- a/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java @@ -34,8 +34,8 @@ import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import 
org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.tasks.Task; diff --git a/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java-e b/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java-e new file mode 100644 index 000000000..803acd9ac --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java-e @@ -0,0 +1,155 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyDouble; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; + +import java.io.IOException; +import java.util.Collections; + +import org.hamcrest.Matchers; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.ad.common.exception.JsonPathNotFoundException; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportService; + +import test.org.opensearch.ad.util.JsonDeserializer; + +public class ThresholdResultTests extends OpenSearchTestCase { + + @SuppressWarnings("unchecked") + public void testNormal() { + TransportService transportService = new TransportService( + Settings.EMPTY, + mock(Transport.class), + null, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + x -> null, + null, + Collections.emptySet() + ); + + ModelManager manager = mock(ModelManager.class); + ThresholdResultTransportAction action = new ThresholdResultTransportAction(mock(ActionFilters.class), transportService, manager); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new ThresholdingResult(0, 1.0d, 0.2)); + return null; + }).when(manager).getThresholdingResult(any(String.class), any(String.class), anyDouble(), any(ActionListener.class)); + + final PlainActionFuture future = new PlainActionFuture<>(); + ThresholdResultRequest request = new ThresholdResultRequest("123", "123-threshold", 2); + action.doExecute(mock(Task.class), request, future); + + ThresholdResultResponse response = future.actionGet(); + assertEquals(0, response.getAnomalyGrade(), 0.001); + assertEquals(1, response.getConfidence(), 0.001); + } + + 
@SuppressWarnings("unchecked") + public void testExecutionException() { + TransportService transportService = new TransportService( + Settings.EMPTY, + mock(Transport.class), + null, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + x -> null, + null, + Collections.emptySet() + ); + + ModelManager manager = mock(ModelManager.class); + ThresholdResultTransportAction action = new ThresholdResultTransportAction(mock(ActionFilters.class), transportService, manager); + doThrow(NullPointerException.class) + .when(manager) + .getThresholdingResult(any(String.class), any(String.class), anyDouble(), any(ActionListener.class)); + + final PlainActionFuture future = new PlainActionFuture<>(); + ThresholdResultRequest request = new ThresholdResultRequest("123", "123-threshold", 2); + action.doExecute(mock(Task.class), request, future); + + expectThrows(NullPointerException.class, () -> future.actionGet()); + } + + public void testSerialzationResponse() throws IOException { + ThresholdResultResponse response = new ThresholdResultResponse(1, 0.8); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + ThresholdResultResponse readResponse = ThresholdResultAction.INSTANCE.getResponseReader().read(streamInput); + assertThat(response.getAnomalyGrade(), equalTo(readResponse.getAnomalyGrade())); + assertThat(response.getConfidence(), equalTo(readResponse.getConfidence())); + } + + public void testJsonResponse() throws IOException, JsonPathNotFoundException { + ThresholdResultResponse response = new ThresholdResultResponse(1, 0.8); + XContentBuilder builder = jsonBuilder(); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String json = Strings.toString(builder); + assertEquals(JsonDeserializer.getDoubleValue(json, ADCommonName.ANOMALY_GRADE_JSON_KEY), response.getAnomalyGrade(), 0.001); + assertEquals(JsonDeserializer.getDoubleValue(json, ADCommonName.CONFIDENCE_JSON_KEY), response.getConfidence(), 0.001); + } + + public void testEmptyDetectorID() { + ActionRequestValidationException e = new ThresholdResultRequest(null, "123-threshold", 2).validate(); + assertThat(e.validationErrors(), Matchers.hasItem(ADCommonMessages.AD_ID_MISSING_MSG)); + } + + public void testEmptyModelID() { + ActionRequestValidationException e = new ThresholdResultRequest("123", "", 2).validate(); + assertThat(e.validationErrors(), Matchers.hasItem(ADCommonMessages.MODEL_ID_MISSING_MSG)); + } + + public void testSerialzationRequest() throws IOException { + ThresholdResultRequest response = new ThresholdResultRequest("123", "123-threshold", 2); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + ThresholdResultRequest readResponse = new ThresholdResultRequest(streamInput); + assertThat(response.getAdID(), equalTo(readResponse.getAdID())); + assertThat(response.getRCFScore(), equalTo(readResponse.getRCFScore())); + } + + public void testJsonRequest() throws IOException, JsonPathNotFoundException { + ThresholdResultRequest request = new ThresholdResultRequest("123", "123-threshold", 2); + XContentBuilder builder = jsonBuilder(); + request.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String json = Strings.toString(builder); + assertEquals(JsonDeserializer.getTextValue(json, ADCommonName.ID_JSON_KEY), request.getAdID()); + assertEquals(JsonDeserializer.getDoubleValue(json, ADCommonName.RCF_SCORE_JSON_KEY), request.getRCFScore(), 0.001); + } +} diff 
--git a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorActionTests.java-e b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorActionTests.java-e new file mode 100644 index 000000000..8270ef965 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorActionTests.java-e @@ -0,0 +1,23 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.junit.Assert; +import org.junit.Test; + +public class ValidateAnomalyDetectorActionTests { + @Test + public void testValidateAnomalyDetectorActionTests() { + Assert.assertNotNull(ValidateAnomalyDetectorAction.INSTANCE.name()); + Assert.assertEquals(ValidateAnomalyDetectorAction.INSTANCE.name(), ValidateAnomalyDetectorAction.NAME); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java index 135a334b7..4a1fae9cb 100644 --- a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java +++ b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java @@ -17,9 +17,9 @@ import org.junit.Test; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; diff --git a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java-e b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java-e new file mode 100644 index 000000000..5b9201391 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequestTests.java-e @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.time.Instant; + +import org.junit.Test; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TestHelpers; + +import com.google.common.collect.ImmutableMap; + +public class ValidateAnomalyDetectorRequestTests extends OpenSearchSingleNodeTestCase { + + @Override + protected NamedWriteableRegistry writableRegistry() { + return getInstanceFromNode(NamedWriteableRegistry.class); + } + + @Test + public void testValidateAnomalyDetectorRequestSerialization() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); + TimeValue requestTimeout = new TimeValue(1000L); + String typeStr = "type"; + + ValidateAnomalyDetectorRequest request1 = new ValidateAnomalyDetectorRequest(detector, typeStr, 1, 1, 1, requestTimeout); + + // Test serialization + BytesStreamOutput output = new BytesStreamOutput(); + request1.writeTo(output); + NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + ValidateAnomalyDetectorRequest request2 = new ValidateAnomalyDetectorRequest(input); + assertEquals("serialization has the wrong detector", request2.getDetector(), detector); + assertEquals("serialization has the wrong typeStr", request2.getValidationType(), typeStr); + assertEquals("serialization has the wrong requestTimeout", request2.getRequestTimeout(), requestTimeout); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponseTests.java b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponseTests.java index 533bfd52d..6c52634d0 100644 --- a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponseTests.java +++ b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponseTests.java @@ -19,7 +19,7 @@ import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.model.DetectorValidationIssue; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; diff --git a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponseTests.java-e b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponseTests.java-e new file mode 100644 index 000000000..6c52634d0 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorResponseTests.java-e @@ -0,0 +1,105 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.model.DetectorValidationIssue; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; + +public class ValidateAnomalyDetectorResponseTests extends AbstractTimeSeriesTest { + + @Test + public void testResponseSerialization() throws IOException { + Map subIssues = new HashMap<>(); + subIssues.put("a", "b"); + subIssues.put("c", "d"); + DetectorValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithSubIssues(subIssues); + ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse(issue); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + ValidateAnomalyDetectorResponse readResponse = ValidateAnomalyDetectorAction.INSTANCE.getResponseReader().read(streamInput); + assertEquals("serialization has the wrong issue", issue, readResponse.getIssue()); + } + + @Test + public void testResponseSerializationWithEmptyIssue() throws IOException { + ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse((DetectorValidationIssue) null); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + ValidateAnomalyDetectorResponse readResponse = ValidateAnomalyDetectorAction.INSTANCE.getResponseReader().read(streamInput); + assertNull("serialization should have empty issue", readResponse.getIssue()); + } + + public void testResponseToXContentWithSubIssues() throws IOException { + Map subIssues = new HashMap<>(); + subIssues.put("a", "b"); + subIssues.put("c", "d"); + DetectorValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithSubIssues(subIssues); + ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse(issue); + String validationResponse = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder())); + String message = issue.getMessage(); + assertEquals( + "{\"detector\":{\"name\":{\"message\":\"" + message + "\",\"sub_issues\":{\"a\":\"b\",\"c\":\"d\"}}}}", + validationResponse + ); + } + + public void testResponseToXContent() throws IOException { + DetectorValidationIssue issue = TestHelpers.randomDetectorValidationIssue(); + ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse(issue); + String validationResponse = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder())); + String message = issue.getMessage(); + assertEquals("{\"detector\":{\"name\":{\"message\":\"" + message + "\"}}}", validationResponse); + } + + public void testResponseToXContentNull() throws IOException { + ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse((DetectorValidationIssue) null); + String validationResponse = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder())); + assertEquals("{}", validationResponse); + } + + public void testResponseToXContentWithIntervalRec() throws IOException { + long intervalRec = 5; + DetectorValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithDetectorIntervalRec(intervalRec); + ValidateAnomalyDetectorResponse 
response = new ValidateAnomalyDetectorResponse(issue); + String validationResponse = TestHelpers.xContentBuilderToString(response.toXContent(TestHelpers.builder())); + assertEquals( + "{\"model\":{\"detection_interval\":{\"message\":\"" + + ADCommonMessages.DETECTOR_INTERVAL_REC + + intervalRec + + "\",\"suggested_value\":{\"period\":{\"interval\":5,\"unit\":\"Minutes\"}}}}}", + validationResponse + ); + } + + @Test + public void testResponseSerializationWithIntervalRec() throws IOException { + long intervalRec = 5; + DetectorValidationIssue issue = TestHelpers.randomDetectorValidationIssueWithDetectorIntervalRec(intervalRec); + ValidateAnomalyDetectorResponse response = new ValidateAnomalyDetectorResponse(issue); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + ValidateAnomalyDetectorResponse readResponse = ValidateAnomalyDetectorAction.INSTANCE.getResponseReader().read(streamInput); + assertEquals(issue, readResponse.getIssue()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java-e b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java-e new file mode 100644 index 000000000..604fc2c46 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportActionTests.java-e @@ -0,0 +1,470 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import java.io.IOException; +import java.net.URL; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Locale; + +import org.junit.Test; +import org.opensearch.ad.ADIntegTestCase; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.google.common.base.Charsets; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Resources; + +public class ValidateAnomalyDetectorTransportActionTests extends ADIntegTestCase { + + @Test + public void testValidateAnomalyDetectorWithNoIssue() throws IOException { + AnomalyDetector anomalyDetector = TestHelpers + .randomAnomalyDetector(timeField, "test-index", ImmutableList.of(sumValueFeature(nameField, ipField + ".is_error", "test-2"))); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); + ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = 
client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertNull(response.getIssue()); + } + + @Test + public void testValidateAnomalyDetectorWithNoIndexFound() throws IOException { + AnomalyDetector anomalyDetector = TestHelpers.randomAnomalyDetector(ImmutableMap.of(), Instant.now()); + ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertNotNull(response.getIssue()); + assertEquals(ValidationIssueType.INDICES, response.getIssue().getType()); + assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); + assertTrue(response.getIssue().getMessage().contains(ADCommonMessages.INDEX_NOT_FOUND)); + } + + @Test + public void testValidateAnomalyDetectorWithDuplicateName() throws IOException { + AnomalyDetector anomalyDetector = TestHelpers.randomAnomalyDetector(timeField, "index-test"); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); + createDetectorIndex(); + createDetector(anomalyDetector); + ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertNotNull(response.getIssue()); + assertEquals(ValidationIssueType.NAME, response.getIssue().getType()); + assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); + } + + @Test + public void testValidateAnomalyDetectorWithNonExistingFeatureField() throws IOException { + Feature maxFeature = maxValueFeature(nameField, "non_existing_field", nameField); + AnomalyDetector anomalyDetector = TestHelpers.randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature)); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); + ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertNotNull(response.getIssue()); + assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); + assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); + assertTrue(response.getIssue().getMessage().contains(CommonMessages.FEATURE_WITH_EMPTY_DATA_MSG)); + assertTrue(response.getIssue().getSubIssues().containsKey(maxFeature.getName())); + assertTrue(CommonMessages.FEATURE_WITH_EMPTY_DATA_MSG.contains(response.getIssue().getSubIssues().get(maxFeature.getName()))); + } + + @Test + public void testValidateAnomalyDetectorWithDuplicateFeatureAggregationNames() throws IOException { + Feature maxFeature = maxValueFeature(nameField, categoryField, "test-1"); + Feature maxFeatureTwo = maxValueFeature(nameField, categoryField, "test-2"); + AnomalyDetector anomalyDetector = TestHelpers + .randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature, maxFeatureTwo)); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, 
"error"); + ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertNotNull(response.getIssue()); + assertTrue(response.getIssue().getMessage().contains("Config has duplicate feature aggregation query names:")); + assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); + assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); + } + + @Test + public void testValidateAnomalyDetectorWithDuplicateFeatureNamesAndDuplicateAggregationNames() throws IOException { + Feature maxFeature = maxValueFeature(nameField, categoryField, nameField); + Feature maxFeatureTwo = maxValueFeature(nameField, categoryField, nameField); + AnomalyDetector anomalyDetector = TestHelpers + .randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature, maxFeatureTwo)); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); + ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertNotNull(response.getIssue()); + assertTrue(response.getIssue().getMessage().contains("Config has duplicate feature aggregation query names:")); + assertTrue(response.getIssue().getMessage().contains("There are duplicate feature names:")); + assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); + assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); + } + + @Test + public void testValidateAnomalyDetectorWithDuplicateFeatureNames() throws IOException { + Feature maxFeature = maxValueFeature(nameField, categoryField, nameField); + Feature maxFeatureTwo = maxValueFeature("test_1", categoryField, nameField); + AnomalyDetector anomalyDetector = TestHelpers + .randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature, maxFeatureTwo)); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); + ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertNotNull(response.getIssue()); + assertTrue( + "actual: " + response.getIssue().getMessage(), + response.getIssue().getMessage().contains("There are duplicate feature names:") + ); + assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); + assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); + } + + @Test + public void testValidateAnomalyDetectorWithInvalidFeatureField() throws IOException { + Feature maxFeature = maxValueFeature(nameField, categoryField, nameField); + AnomalyDetector anomalyDetector = TestHelpers.randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature)); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); + ValidateAnomalyDetectorRequest request = new 
ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertNotNull(response.getIssue()); + assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); + assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); + assertTrue(response.getIssue().getMessage().contains(CommonMessages.FEATURE_WITH_INVALID_QUERY_MSG)); + assertTrue(response.getIssue().getSubIssues().containsKey(maxFeature.getName())); + assertTrue(CommonMessages.FEATURE_WITH_INVALID_QUERY_MSG.contains(response.getIssue().getSubIssues().get(maxFeature.getName()))); + } + + @Test + public void testValidateAnomalyDetectorWithUnknownFeatureField() throws IOException { + AggregationBuilder aggregationBuilder = TestHelpers.parseAggregation("{\"test\":{\"terms\":{\"field\":\"type\"}}}"); + AnomalyDetector anomalyDetector = TestHelpers + .randomAnomalyDetector( + timeField, + "test-index", + ImmutableList.of(new Feature(randomAlphaOfLength(5), nameField, true, aggregationBuilder)) + ); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); + ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertNotNull(response.getIssue()); + assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); + assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); + assertTrue(response.getIssue().getMessage().contains(CommonMessages.UNKNOWN_SEARCH_QUERY_EXCEPTION_MSG)); + assertTrue(response.getIssue().getSubIssues().containsKey(nameField)); + } + + @Test + public void testValidateAnomalyDetectorWithMultipleInvalidFeatureField() throws IOException { + Feature maxFeature = maxValueFeature(nameField, categoryField, nameField); + Feature maxFeatureTwo = maxValueFeature("test_two", categoryField, "test_two"); + AnomalyDetector anomalyDetector = TestHelpers + .randomAnomalyDetector(timeField, "test-index", ImmutableList.of(maxFeature, maxFeatureTwo)); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); + ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertNotNull(response.getIssue()); + assertEquals(response.getIssue().getSubIssues().keySet().size(), 2); + assertEquals(ValidationIssueType.FEATURE_ATTRIBUTES, response.getIssue().getType()); + assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); + assertTrue(response.getIssue().getMessage().contains(CommonMessages.FEATURE_WITH_INVALID_QUERY_MSG)); + assertTrue(response.getIssue().getSubIssues().containsKey(maxFeature.getName())); + assertTrue(CommonMessages.FEATURE_WITH_INVALID_QUERY_MSG.contains(response.getIssue().getSubIssues().get(maxFeature.getName()))); + } + + @Test + public void testValidateAnomalyDetectorWithCustomResultIndex() throws IOException { + String 
resultIndex = ADCommonName.CUSTOM_RESULT_INDEX_PREFIX + "test"; + createCustomADResultIndex(resultIndex); + AnomalyDetector anomalyDetector = TestHelpers + .randomDetector( + ImmutableList.of(TestHelpers.randomFeature()), + randomAlphaOfLength(5).toLowerCase(Locale.ROOT), + randomIntBetween(1, 5), + timeField, + null, + resultIndex + ); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); + ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertNull(response.getIssue()); + } + + @Test + public void testValidateAnomalyDetectorWithCustomResultIndexCreated() throws IOException { + testValidateAnomalyDetectorWithCustomResultIndex(true); + } + + @Test + public void testValidateAnomalyDetectorWithCustomResultIndexPresentButNotCreated() throws IOException { + testValidateAnomalyDetectorWithCustomResultIndex(false); + + } + + @Test + public void testValidateAnomalyDetectorWithCustomResultIndexWithInvalidMapping() throws IOException { + String resultIndex = ADCommonName.CUSTOM_RESULT_INDEX_PREFIX + "test"; + URL url = ADIndexManagement.class.getClassLoader().getResource("mappings/anomaly-checkpoint.json"); + createIndex(resultIndex, Resources.toString(url, Charsets.UTF_8)); + AnomalyDetector anomalyDetector = TestHelpers + .randomDetector( + ImmutableList.of(TestHelpers.randomFeature()), + randomAlphaOfLength(5).toLowerCase(Locale.ROOT), + randomIntBetween(1, 5), + timeField, + null, + resultIndex + ); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); + ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertEquals(ValidationIssueType.RESULT_INDEX, response.getIssue().getType()); + assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); + assertTrue(response.getIssue().getMessage().contains(CommonMessages.INVALID_RESULT_INDEX_MAPPING)); + } + + private void testValidateAnomalyDetectorWithCustomResultIndex(boolean resultIndexCreated) throws IOException { + String resultIndex = ADCommonName.CUSTOM_RESULT_INDEX_PREFIX + "test"; + if (resultIndexCreated) { + createCustomADResultIndex(resultIndex); + } + AnomalyDetector anomalyDetector = TestHelpers + .randomDetector( + ImmutableList.of(TestHelpers.randomFeature()), + randomAlphaOfLength(5).toLowerCase(Locale.ROOT), + randomIntBetween(1, 5), + timeField, + null, + resultIndex + ); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); + ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertNull(response.getIssue()); + } + + @Test + public void testValidateAnomalyDetectorWithInvalidDetectorName() throws IOException { + AnomalyDetector anomalyDetector = new AnomalyDetector( + 
randomAlphaOfLength(5), + randomLong(), + "#$32", + randomAlphaOfLength(5), + timeField, + ImmutableList.of(randomAlphaOfLength(5).toLowerCase(Locale.ROOT)), + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomQuery(), + TestHelpers.randomIntervalTimeConfiguration(), + TestHelpers.randomIntervalTimeConfiguration(), + TimeSeriesSettings.DEFAULT_SHINGLE_SIZE, + null, + 1, + Instant.now(), + null, + TestHelpers.randomUser(), + null, + TestHelpers.randomImputationOption() + ); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); + ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertEquals(ValidationIssueType.NAME, response.getIssue().getType()); + assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); + assertEquals(CommonMessages.INVALID_NAME, response.getIssue().getMessage()); + } + + @Test + public void testValidateAnomalyDetectorWithDetectorNameTooLong() throws IOException { + AnomalyDetector anomalyDetector = new AnomalyDetector( + randomAlphaOfLength(5), + randomLong(), + "abababababababababababababababababababababababababababababababababababababababababababababababab", + randomAlphaOfLength(5), + timeField, + ImmutableList.of(randomAlphaOfLength(5).toLowerCase(Locale.ROOT)), + ImmutableList.of(TestHelpers.randomFeature()), + TestHelpers.randomQuery(), + TestHelpers.randomIntervalTimeConfiguration(), + TestHelpers.randomIntervalTimeConfiguration(), + TimeSeriesSettings.DEFAULT_SHINGLE_SIZE, + null, + 1, + Instant.now(), + null, + TestHelpers.randomUser(), + null, + TestHelpers.randomImputationOption() + ); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); + ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertEquals(ValidationIssueType.NAME, response.getIssue().getType()); + assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); + assertTrue(response.getIssue().getMessage().contains("Name should be shortened. 
The maximum limit is")); + } + + @Test + public void testValidateAnomalyDetectorWithNonExistentTimefield() throws IOException { + AnomalyDetector anomalyDetector = TestHelpers.randomAnomalyDetector(ImmutableMap.of(), Instant.now()); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); + ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertEquals(ValidationIssueType.TIMEFIELD_FIELD, response.getIssue().getType()); + assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); + assertEquals( + String.format(Locale.ROOT, CommonMessages.NON_EXISTENT_TIMESTAMP, anomalyDetector.getTimeField()), + response.getIssue().getMessage() + ); + } + + @Test + public void testValidateAnomalyDetectorWithNonDateTimeField() throws IOException { + AnomalyDetector anomalyDetector = TestHelpers.randomAnomalyDetector(categoryField, "index-test"); + ingestTestDataValidate(anomalyDetector.getIndices().get(0), Instant.now().minus(1, ChronoUnit.DAYS), 1, "error"); + ValidateAnomalyDetectorRequest request = new ValidateAnomalyDetectorRequest( + anomalyDetector, + ValidationAspect.DETECTOR.getName(), + 5, + 5, + 5, + new TimeValue(5_000L) + ); + ValidateAnomalyDetectorResponse response = client().execute(ValidateAnomalyDetectorAction.INSTANCE, request).actionGet(5_000); + assertEquals(ValidationIssueType.TIMEFIELD_FIELD, response.getIssue().getType()); + assertEquals(ValidationAspect.DETECTOR, response.getIssue().getAspect()); + assertEquals( + String.format(Locale.ROOT, CommonMessages.INVALID_TIMESTAMP, anomalyDetector.getTimeField()), + response.getIssue().getMessage() + ); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/handler/ADSearchHandlerTests.java-e b/src/test/java/org/opensearch/ad/transport/handler/ADSearchHandlerTests.java-e new file mode 100644 index 000000000..793bdc7b9 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/handler/ADSearchHandlerTests.java-e @@ -0,0 +1,93 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport.handler; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.timeseries.TestHelpers.matchAllRequest; + +import org.junit.Before; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.ADUnitTestCase; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.threadpool.ThreadPool; + +public class ADSearchHandlerTests extends ADUnitTestCase { + + private Client client; + private Settings settings; + private ClusterService clusterService; + private ADSearchHandler searchHandler; + private ClusterSettings clusterSettings; + + private SearchRequest request; + + private ActionListener listener; + + @SuppressWarnings("unchecked") + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + settings = Settings.builder().put(FILTER_BY_BACKEND_ROLES.getKey(), false).build(); + clusterSettings = clusterSetting(settings, FILTER_BY_BACKEND_ROLES); + clusterService = new ClusterService(settings, clusterSettings, null); + client = mock(Client.class); + searchHandler = new ADSearchHandler(settings, clusterService, client); + + ThreadContext threadContext = new ThreadContext(settings); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alice|odfe,aes|engineering,operations"); + org.opensearch.threadpool.ThreadPool mockThreadPool = mock(ThreadPool.class); + when(client.threadPool()).thenReturn(mockThreadPool); + when(client.threadPool().getThreadContext()).thenReturn(threadContext); + when(mockThreadPool.getThreadContext()).thenReturn(threadContext); + + request = mock(SearchRequest.class); + listener = mock(ActionListener.class); + } + + public void testSearchException() { + doThrow(new RuntimeException("test")).when(client).search(any(), any()); + searchHandler.search(request, listener); + verify(listener, times(1)).onFailure(any()); + } + + public void testFilterEnabledWithWrongSearch() { + settings = Settings.builder().put(FILTER_BY_BACKEND_ROLES.getKey(), true).build(); + clusterService = new ClusterService(settings, clusterSettings, null); + + searchHandler = new ADSearchHandler(settings, clusterService, client); + searchHandler.search(request, listener); + verify(listener, times(1)).onFailure(any()); + } + + public void testFilterEnabled() { + settings = Settings.builder().put(FILTER_BY_BACKEND_ROLES.getKey(), true).build(); + clusterService = new ClusterService(settings, clusterSettings, null); + + searchHandler = new ADSearchHandler(settings, clusterService, client); + searchHandler.search(matchAllRequest(), listener); + verify(client, times(1)).search(any(), any()); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/handler/AbstractIndexHandlerTest.java-e b/src/test/java/org/opensearch/ad/transport/handler/AbstractIndexHandlerTest.java-e new file mode 100644 index 
000000000..4d1c1ed44 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/handler/AbstractIndexHandlerTest.java-e @@ -0,0 +1,147 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport.handler; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; +import static org.opensearch.timeseries.TestHelpers.createIndexBlockedState; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Locale; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.transport.AnomalyResultTests; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.IndexUtils; +import org.opensearch.ad.util.Throttler; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; + +public abstract class AbstractIndexHandlerTest extends AbstractTimeSeriesTest { + enum IndexCreation { + RUNTIME_EXCEPTION, + RESOURCE_EXISTS_EXCEPTION, + ACKED, + NOT_ACKED + } + + protected static Settings settings; + protected ClientUtil clientUtil; + protected ThreadPool context; + protected IndexUtils indexUtil; + protected String detectorId = "123"; + + @Mock + protected Client client; + + @Mock + protected ADIndexManagement anomalyDetectionIndices; + + @Mock + protected Throttler throttler; + + @Mock + protected ClusterService clusterService; + + @Mock + protected IndexNameExpressionResolver indexNameResolver; + + @BeforeClass + public static void setUpBeforeClass() { + setUpThreadPool(AnomalyResultTests.class.getSimpleName()); + settings = Settings + .builder() + .put("plugins.anomaly_detection.max_retry_for_backoff", 2) + .put("plugins.anomaly_detection.backoff_initial_delay", TimeValue.timeValueMillis(1)) + .build(); + } + + @AfterClass + public static void tearDownAfterClass() { + tearDownThreadPool(); + settings = null; + } + + @Override + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.initMocks(this); + setWriteBlockAdResultIndex(false); + context = TestHelpers.createThreadPool(); + clientUtil = new ClientUtil(settings, client, throttler, context); + indexUtil = new IndexUtils(client, clientUtil, clusterService, indexNameResolver); + } + + protected void setWriteBlockAdResultIndex(boolean blocked) { + String indexName = randomAlphaOfLength(10); + Settings settings = blocked + ? 
Settings.builder().put(IndexMetadata.INDEX_BLOCKS_WRITE_SETTING.getKey(), true).build() + : Settings.EMPTY; + ClusterState blockedClusterState = createIndexBlockedState(indexName, settings, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); + when(clusterService.state()).thenReturn(blockedClusterState); + when(indexNameResolver.concreteIndexNames(any(), any(), any(String.class))).thenReturn(new String[] { indexName }); + } + + @SuppressWarnings("unchecked") + protected void setUpSavingAnomalyResultIndex(boolean anomalyResultIndexExists, IndexCreation creationResult) throws IOException { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + args.length >= 1 + ); + ActionListener listener = invocation.getArgument(0); + assertTrue(listener != null); + switch (creationResult) { + case RUNTIME_EXCEPTION: + listener.onFailure(new RuntimeException()); + break; + case RESOURCE_EXISTS_EXCEPTION: + listener.onFailure(new ResourceAlreadyExistsException(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)); + break; + case ACKED: + listener.onResponse(new CreateIndexResponse(true, true, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)); + break; + case NOT_ACKED: + listener.onResponse(new CreateIndexResponse(false, false, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)); + break; + default: + assertTrue("should not reach here", false); + break; + } + return null; + }).when(anomalyDetectionIndices).initDefaultResultIndexDirectly(any()); + when(anomalyDetectionIndices.doesDefaultResultIndexExist()).thenReturn(anomalyResultIndexExists); + } + + protected void setUpSavingAnomalyResultIndex(boolean anomalyResultIndexExists) throws IOException { + setUpSavingAnomalyResultIndex(anomalyResultIndexExists, IndexCreation.ACKED); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java index 771088893..a2635ed8f 100644 --- a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java +++ b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java @@ -43,9 +43,9 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; -import org.opensearch.index.Index; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.engine.VersionConflictEngineException; -import org.opensearch.index.shard.ShardId; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.TestHelpers; diff --git a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java-e b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java-e new file mode 100644 index 000000000..726a3f251 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java-e @@ -0,0 +1,218 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport.handler; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ad.constant.ADCommonName.ANOMALY_RESULT_INDEX_ALIAS; + +import java.io.IOException; +import java.time.Clock; +import java.util.Optional; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequestBuilder; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.ad.ADUnitTestCase; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.util.ClientUtil; +import org.opensearch.ad.util.IndexUtils; +import org.opensearch.ad.util.Throttler; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.index.Index; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TestHelpers; + +import com.google.common.collect.ImmutableList; + +public class AnomalyResultBulkIndexHandlerTests extends ADUnitTestCase { + + private AnomalyResultBulkIndexHandler bulkIndexHandler; + private Client client; + private IndexUtils indexUtils; + private ActionListener listener; + private ADIndexManagement anomalyDetectionIndices; + + @Override + public void setUp() throws Exception { + super.setUp(); + anomalyDetectionIndices = mock(ADIndexManagement.class); + client = mock(Client.class); + Settings settings = Settings.EMPTY; + Clock clock = mock(Clock.class); + Throttler throttler = new Throttler(clock); + ThreadPool threadpool = mock(ThreadPool.class); + ClientUtil clientUtil = new ClientUtil(Settings.EMPTY, client, throttler, threadpool); + indexUtils = mock(IndexUtils.class); + ClusterService clusterService = mock(ClusterService.class); + ThreadPool threadPool = mock(ThreadPool.class); + bulkIndexHandler = new AnomalyResultBulkIndexHandler( + client, + settings, + threadPool, + clientUtil, + indexUtils, + clusterService, + anomalyDetectionIndices + ); + listener = spy(new ActionListener() { + @Override + public void onResponse(BulkResponse bulkItemResponses) {} + + @Override + public void onFailure(Exception e) {} + }); + } + + public void testNullAnomalyResults() { + bulkIndexHandler.bulkIndexAnomalyResult(null, null, listener); + verify(listener, times(1)).onResponse(null); + verify(anomalyDetectionIndices, never()).doesConfigIndexExist(); + } + + public void testAnomalyResultBulkIndexHandler_IndexNotExist() { + when(anomalyDetectionIndices.doesIndexExist("testIndex")).thenReturn(false); + AnomalyResult anomalyResult = mock(AnomalyResult.class); + when(anomalyResult.getConfigId()).thenReturn("testId"); + + bulkIndexHandler.bulkIndexAnomalyResult("testIndex", ImmutableList.of(anomalyResult), listener); + verify(listener, 
times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Can't find result index testIndex", exceptionCaptor.getValue().getMessage()); + } + + public void testAnomalyResultBulkIndexHandler_InValidResultIndexMapping() { + when(anomalyDetectionIndices.doesIndexExist("testIndex")).thenReturn(true); + when(anomalyDetectionIndices.isValidResultIndexMapping("testIndex")).thenReturn(false); + AnomalyResult anomalyResult = mock(AnomalyResult.class); + when(anomalyResult.getConfigId()).thenReturn("testId"); + + bulkIndexHandler.bulkIndexAnomalyResult("testIndex", ImmutableList.of(anomalyResult), listener); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("wrong index mapping of custom AD result index", exceptionCaptor.getValue().getMessage()); + } + + public void testAnomalyResultBulkIndexHandler_FailBulkIndexAnomaly() throws IOException { + when(anomalyDetectionIndices.doesIndexExist("testIndex")).thenReturn(true); + when(anomalyDetectionIndices.isValidResultIndexMapping("testIndex")).thenReturn(true); + AnomalyResult anomalyResult = mock(AnomalyResult.class); + when(anomalyResult.getConfigId()).thenReturn("testId"); + when(anomalyResult.toXContent(any(), any())).thenThrow(new RuntimeException()); + + bulkIndexHandler.bulkIndexAnomalyResult("testIndex", ImmutableList.of(anomalyResult), listener); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to prepare request to bulk index anomaly results", exceptionCaptor.getValue().getMessage()); + } + + public void testCreateADResultIndexNotAcknowledged() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(new CreateIndexResponse(false, false, ANOMALY_RESULT_INDEX_ALIAS)); + return null; + }).when(anomalyDetectionIndices).initDefaultResultIndexDirectly(any()); + bulkIndexHandler.bulkIndexAnomalyResult(null, ImmutableList.of(mock(AnomalyResult.class)), listener); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Creating anomaly result index with mappings call not acknowledged", exceptionCaptor.getValue().getMessage()); + } + + public void testWrongAnomalyResult() { + BulkRequestBuilder bulkRequestBuilder = new BulkRequestBuilder(client, BulkAction.INSTANCE); + doReturn(bulkRequestBuilder).when(client).prepareBulk(); + doReturn(true).when(anomalyDetectionIndices).doesDefaultResultIndexExist(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + BulkItemResponse[] bulkItemResponses = new BulkItemResponse[2]; + String indexName = ANOMALY_RESULT_INDEX_ALIAS; + String type = "_doc"; + String idPrefix = "id"; + String uuid = "uuid"; + int shardIntId = 0; + ShardId shardId = new ShardId(new Index(indexName, uuid), shardIntId); + BulkItemResponse.Failure failure = new BulkItemResponse.Failure( + ANOMALY_RESULT_INDEX_ALIAS, + randomAlphaOfLength(5), + new VersionConflictEngineException(new ShardId(ANOMALY_RESULT_INDEX_ALIAS, "", 1), "id", "test") + ); + bulkItemResponses[0] = new BulkItemResponse(0, randomFrom(DocWriteRequest.OpType.values()), failure); + bulkItemResponses[1] = new BulkItemResponse( + 1, + randomFrom(DocWriteRequest.OpType.values()), + new IndexResponse(shardId, idPrefix + 1, 1, 1, randomInt(), true) + ); + BulkResponse bulkResponse = new BulkResponse(bulkItemResponses, 10); + listener.onResponse(bulkResponse); + return null; + }).when(client).bulk(any(), any()); + bulkIndexHandler + .bulkIndexAnomalyResult(null, 
ImmutableList.of(wrongAnomalyResult(), TestHelpers.randomAnomalyDetectResult()), listener); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue().getMessage().contains("VersionConflictEngineException")); + } + + public void testBulkSaveException() { + BulkRequestBuilder bulkRequestBuilder = mock(BulkRequestBuilder.class); + doReturn(bulkRequestBuilder).when(client).prepareBulk(); + doReturn(true).when(anomalyDetectionIndices).doesDefaultResultIndexExist(); + + String testError = randomAlphaOfLength(5); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException(testError)); + return null; + }).when(client).bulk(any(), any()); + + bulkIndexHandler.bulkIndexAnomalyResult(null, ImmutableList.of(TestHelpers.randomAnomalyDetectResult()), listener); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals(testError, exceptionCaptor.getValue().getMessage()); + } + + private AnomalyResult wrongAnomalyResult() { + return new AnomalyResult( + randomAlphaOfLength(5), + null, + randomDouble(), + randomDouble(), + randomDouble(), + null, + null, + null, + null, + null, + randomAlphaOfLength(5), + Optional.empty(), + null, + null, + null, + null, + null, + null, + null, + randomDoubleBetween(1.1, 10.0, true) + ); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java-e b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java-e new file mode 100644 index 000000000..89367a72b --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java-e @@ -0,0 +1,231 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport.handler; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.time.Clock; +import java.util.Arrays; +import java.util.Locale; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.opensearch.action.ActionListener; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.TimeSeriesException; + +public class AnomalyResultHandlerTests extends AbstractIndexHandlerTest { + @Mock + private NodeStateManager nodeStateManager; + + @Mock + private Clock clock; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + super.setUpLog4jForJUnit(AnomalyIndexHandler.class); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + super.tearDownLog4jForJUnit(); + } + + @Rule + public ExpectedException expectedEx = ExpectedException.none(); + + @Test + public void testSavingAdResult() throws IOException { + setUpSavingAnomalyResultIndex(false); + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. 
Its content is %s", args.length, Arrays.toString(args)), + args.length >= 2 + ); + IndexRequest request = invocation.getArgument(0); + ActionListener listener = invocation.getArgument(1); + assertTrue(request != null && listener != null); + listener.onResponse(mock(IndexResponse.class)); + return null; + }).when(client).index(any(IndexRequest.class), ArgumentMatchers.>any()); + AnomalyIndexHandler handler = new AnomalyIndexHandler( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService + ); + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null); + assertEquals(1, testAppender.countMessage(AnomalyIndexHandler.SUCCESS_SAVING_MSG, true)); + } + + @Test + public void testSavingFailureNotRetry() throws InterruptedException, IOException { + savingFailureTemplate(false, 1, true); + + assertEquals(1, testAppender.countMessage(AnomalyIndexHandler.FAIL_TO_SAVE_ERR_MSG, true)); + assertTrue(!testAppender.containsMessage(AnomalyIndexHandler.SUCCESS_SAVING_MSG, true)); + assertTrue(!testAppender.containsMessage(AnomalyIndexHandler.RETRY_SAVING_ERR_MSG, true)); + } + + @Test + public void testSavingFailureRetry() throws InterruptedException, IOException { + setWriteBlockAdResultIndex(false); + savingFailureTemplate(true, 3, true); + + assertEquals(2, testAppender.countMessage(AnomalyIndexHandler.RETRY_SAVING_ERR_MSG, true)); + assertEquals(1, testAppender.countMessage(AnomalyIndexHandler.FAIL_TO_SAVE_ERR_MSG, true)); + assertTrue(!testAppender.containsMessage(AnomalyIndexHandler.SUCCESS_SAVING_MSG, true)); + } + + @Test + public void testIndexWriteBlock() { + setWriteBlockAdResultIndex(true); + AnomalyIndexHandler handler = new AnomalyIndexHandler( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService + ); + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null); + + assertTrue(testAppender.containsMessage(AnomalyIndexHandler.CANNOT_SAVE_ERR_MSG, true)); + } + + @Test + public void testAdResultIndexExist() throws IOException { + setUpSavingAnomalyResultIndex(false, IndexCreation.RESOURCE_EXISTS_EXCEPTION); + AnomalyIndexHandler handler = new AnomalyIndexHandler( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService + ); + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null); + verify(client, times(1)).index(any(), any()); + } + + @Test + public void testAdResultIndexOtherException() throws IOException { + expectedEx.expect(TimeSeriesException.class); + expectedEx.expectMessage("Error in saving .opendistro-anomaly-results for detector " + detectorId); + + setUpSavingAnomalyResultIndex(false, IndexCreation.RUNTIME_EXCEPTION); + AnomalyIndexHandler handler = new AnomalyIndexHandler( + client, + settings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService + ); + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null); + verify(client, never()).index(any(), any()); + } + + /** + * Template to test exponential backoff retry during saving anomaly result. + * + * @param throwOpenSearchRejectedExecutionException whether to throw + * OpenSearchRejectedExecutionException in the + * client::index mock or not + * @param latchCount used for coordinating. 
Equal to + * number of expected retries plus 1. + * @throws InterruptedException if thread execution is interrupted + * @throws IOException if IO failures + */ + @SuppressWarnings("unchecked") + private void savingFailureTemplate(boolean throwOpenSearchRejectedExecutionException, int latchCount, boolean adResultIndexExists) + throws InterruptedException, + IOException { + setUpSavingAnomalyResultIndex(adResultIndexExists); + + final CountDownLatch backoffLatch = new CountDownLatch(latchCount); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + assertTrue( + String.format(Locale.ROOT, "The size of args is %d. Its content is %s", args.length, Arrays.toString(args)), + args.length >= 2 + ); + IndexRequest request = invocation.getArgument(0); + ActionListener listener = invocation.getArgument(1); + assertTrue(request != null && listener != null); + if (throwOpenSearchRejectedExecutionException) { + listener.onFailure(new OpenSearchRejectedExecutionException("")); + } else { + listener.onFailure(new IllegalArgumentException()); + } + + backoffLatch.countDown(); + return null; + }).when(client).index(any(IndexRequest.class), ArgumentMatchers.>any()); + + Settings backoffSettings = Settings + .builder() + .put("plugins.anomaly_detection.max_retry_for_backoff", 2) + .put("plugins.anomaly_detection.backoff_initial_delay", TimeValue.timeValueMillis(1)) + .build(); + + AnomalyIndexHandler handler = new AnomalyIndexHandler( + client, + backoffSettings, + threadPool, + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService + ); + + handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null); + + backoffLatch.await(1, TimeUnit.MINUTES); + } +} diff --git a/src/test/java/org/opensearch/ad/transport/handler/MultiEntityResultHandlerTests.java-e b/src/test/java/org/opensearch/ad/transport/handler/MultiEntityResultHandlerTests.java-e new file mode 100644 index 000000000..4c8446577 --- /dev/null +++ b/src/test/java/org/opensearch/ad/transport/handler/MultiEntityResultHandlerTests.java-e @@ -0,0 +1,205 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport.handler; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.junit.Test; +import org.mockito.ArgumentMatchers; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.ratelimit.RequestPriority; +import org.opensearch.ad.ratelimit.ResultWriteRequest; +import org.opensearch.ad.transport.ADResultBulkAction; +import org.opensearch.ad.transport.ADResultBulkRequest; +import org.opensearch.ad.transport.ADResultBulkResponse; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.TimeSeriesException; + +public class MultiEntityResultHandlerTests extends AbstractIndexHandlerTest { + private MultiEntityResultHandler handler; + private ADResultBulkRequest request; + private ADResultBulkResponse response; + + @Override + public void setUp() throws Exception { + super.setUp(); + + handler = new MultiEntityResultHandler( + client, + settings, + threadPool, + anomalyDetectionIndices, + clientUtil, + indexUtil, + clusterService + ); + + request = new ADResultBulkRequest(); + ResultWriteRequest resultWriteRequest = new ResultWriteRequest( + Instant.now().plus(10, ChronoUnit.MINUTES).toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + TestHelpers.randomAnomalyDetectResult(), + null + ); + request.add(resultWriteRequest); + + response = new ADResultBulkResponse(); + + super.setUpLog4jForJUnit(MultiEntityResultHandler.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(client).execute(eq(ADResultBulkAction.INSTANCE), any(), ArgumentMatchers.>any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + super.tearDownLog4jForJUnit(); + } + + @Test + public void testIndexWriteBlock() throws InterruptedException { + setWriteBlockAdResultIndex(true); + + CountDownLatch verified = new CountDownLatch(1); + + handler.flush(request, ActionListener.wrap(response -> { + assertTrue("Should not reach here ", false); + verified.countDown(); + }, exception -> { + assertTrue(exception instanceof TimeSeriesException); + assertTrue( + "actual: " + exception.getMessage(), + exception.getMessage().contains(MultiEntityResultHandler.CANNOT_SAVE_RESULT_ERR_MSG) + ); + verified.countDown(); + })); + + assertTrue(verified.await(100, TimeUnit.SECONDS)); + } + + @Test + public void testSavingAdResult() throws IOException, InterruptedException { + setUpSavingAnomalyResultIndex(false); + + CountDownLatch verified = new CountDownLatch(1); + handler.flush(request, ActionListener.wrap(response -> { verified.countDown(); }, exception -> { + assertTrue("Should not reach here ", false); + verified.countDown(); + })); + assertTrue(verified.await(100, TimeUnit.SECONDS)); + assertEquals(1, testAppender.countMessage(MultiEntityResultHandler.SUCCESS_SAVING_RESULT_MSG, false)); + } + + @Test + public void testSavingFailure() throws IOException, InterruptedException { + setUpSavingAnomalyResultIndex(false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException()); + return null; + }).when(client).execute(eq(ADResultBulkAction.INSTANCE), any(), 
ArgumentMatchers.>any()); + + CountDownLatch verified = new CountDownLatch(1); + handler.flush(request, ActionListener.wrap(response -> { + assertTrue("Should not reach here ", false); + verified.countDown(); + }, exception -> { + assertTrue(exception instanceof RuntimeException); + verified.countDown(); + })); + assertTrue(verified.await(100, TimeUnit.SECONDS)); + } + + @Test + public void testAdResultIndexExists() throws IOException, InterruptedException { + setUpSavingAnomalyResultIndex(true); + + CountDownLatch verified = new CountDownLatch(1); + handler.flush(request, ActionListener.wrap(response -> { verified.countDown(); }, exception -> { + assertTrue("Should not reach here ", false); + verified.countDown(); + })); + assertTrue(verified.await(100, TimeUnit.SECONDS)); + assertEquals(1, testAppender.countMessage(MultiEntityResultHandler.SUCCESS_SAVING_RESULT_MSG, false)); + } + + @Test + public void testNothingToSave() throws IOException, InterruptedException { + setUpSavingAnomalyResultIndex(false); + + CountDownLatch verified = new CountDownLatch(1); + handler.flush(new ADResultBulkRequest(), ActionListener.wrap(response -> { + assertTrue("Should not reach here ", false); + verified.countDown(); + }, exception -> { + assertTrue(exception instanceof TimeSeriesException); + verified.countDown(); + })); + assertTrue(verified.await(100, TimeUnit.SECONDS)); + } + + @Test + public void testCreateUnAcked() throws IOException, InterruptedException { + setUpSavingAnomalyResultIndex(false, IndexCreation.NOT_ACKED); + + CountDownLatch verified = new CountDownLatch(1); + handler.flush(request, ActionListener.wrap(response -> { + assertTrue("Should not reach here ", false); + verified.countDown(); + }, exception -> { + assertTrue(exception instanceof TimeSeriesException); + verified.countDown(); + })); + assertTrue(verified.await(100, TimeUnit.SECONDS)); + } + + @Test + public void testCreateRuntimeException() throws IOException, InterruptedException { + setUpSavingAnomalyResultIndex(false, IndexCreation.RUNTIME_EXCEPTION); + + CountDownLatch verified = new CountDownLatch(1); + handler.flush(request, ActionListener.wrap(response -> { + assertTrue("Should not reach here ", false); + verified.countDown(); + }, exception -> { + assertTrue(exception instanceof RuntimeException); + verified.countDown(); + })); + assertTrue(verified.await(100, TimeUnit.SECONDS)); + } + + @Test + public void testCreateResourcExistsException() throws IOException, InterruptedException { + setUpSavingAnomalyResultIndex(false, IndexCreation.RESOURCE_EXISTS_EXCEPTION); + + CountDownLatch verified = new CountDownLatch(1); + handler.flush(request, ActionListener.wrap(response -> { verified.countDown(); }, exception -> { + assertTrue("Should not reach here ", false); + verified.countDown(); + })); + assertTrue(verified.await(100, TimeUnit.SECONDS)); + assertEquals(1, testAppender.countMessage(MultiEntityResultHandler.SUCCESS_SAVING_RESULT_MSG, false)); + } +} diff --git a/src/test/java/org/opensearch/ad/util/ArrayEqMatcher.java-e b/src/test/java/org/opensearch/ad/util/ArrayEqMatcher.java-e new file mode 100644 index 000000000..51b5b0c26 --- /dev/null +++ b/src/test/java/org/opensearch/ad/util/ArrayEqMatcher.java-e @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.util; + +import java.util.Arrays; + +import org.mockito.ArgumentMatcher; + +/** + * An argument matcher based on deep equality needed for array types. + * + * The default eq or aryEq from Mockito fails on nested array types, such as a matrix. + * This matcher takes the expected argument and returns a match result based on deep equality. + */ +public class ArrayEqMatcher implements ArgumentMatcher { + + private final T expected; + + /** + * Constructor with expected value. + * + * @param expected the value expected to match by equality + */ + public ArrayEqMatcher(T expected) { + this.expected = expected; + } + + @Override + public boolean matches(T actual) { + return Arrays.deepEquals((Object[]) expected, (Object[]) actual); + } +} diff --git a/src/test/java/org/opensearch/ad/util/BulkUtilTests.java b/src/test/java/org/opensearch/ad/util/BulkUtilTests.java index 4aea4ac55..aadc2d999 100644 --- a/src/test/java/org/opensearch/ad/util/BulkUtilTests.java +++ b/src/test/java/org/opensearch/ad/util/BulkUtilTests.java @@ -21,9 +21,9 @@ import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.index.Index; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.engine.VersionConflictEngineException; -import org.opensearch.index.shard.ShardId; import org.opensearch.test.OpenSearchTestCase; public class BulkUtilTests extends OpenSearchTestCase { diff --git a/src/test/java/org/opensearch/ad/util/BulkUtilTests.java-e b/src/test/java/org/opensearch/ad/util/BulkUtilTests.java-e new file mode 100644 index 000000000..3ce2dbe61 --- /dev/null +++ b/src/test/java/org/opensearch/ad/util/BulkUtilTests.java-e @@ -0,0 +1,59 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.util; + +import java.util.List; + +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkItemResponse.Failure; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.index.Index; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.test.OpenSearchTestCase; + +public class BulkUtilTests extends OpenSearchTestCase { + public void testGetFailedIndexRequest() { + BulkItemResponse[] itemResponses = new BulkItemResponse[2]; + String indexName = "index"; + String type = "_doc"; + String idPrefix = "id"; + String uuid = "uuid"; + int shardIntId = 0; + ShardId shardId = new ShardId(new Index(indexName, uuid), shardIntId); + itemResponses[0] = new BulkItemResponse( + 0, + randomFrom(DocWriteRequest.OpType.values()), + new Failure(indexName, idPrefix + 0, new VersionConflictEngineException(shardId, "", "blah")) + ); + itemResponses[1] = new BulkItemResponse( + 1, + randomFrom(DocWriteRequest.OpType.values()), + new IndexResponse(shardId, idPrefix + 1, 1, 1, randomInt(), true) + ); + BulkResponse response = new BulkResponse(itemResponses, 0); + + BulkRequest request = new BulkRequest(); + for (int i = 0; i < 2; i++) { + request.add(new IndexRequest(indexName).id(idPrefix + i).source(XContentType.JSON, "field", "value")); + } + + List retry = BulkUtil.getFailedIndexRequest(request, response); + assertEquals(1, retry.size()); + assertEquals(idPrefix + 0, retry.get(0).id()); + } +} diff --git a/src/test/java/org/opensearch/ad/util/DateUtilsTests.java-e b/src/test/java/org/opensearch/ad/util/DateUtilsTests.java-e new file mode 100644 index 000000000..593445b01 --- /dev/null +++ b/src/test/java/org/opensearch/ad/util/DateUtilsTests.java-e @@ -0,0 +1,24 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.util; + +import java.time.Duration; + +import org.opensearch.common.unit.TimeValue; +import org.opensearch.test.OpenSearchTestCase; + +public class DateUtilsTests extends OpenSearchTestCase { + public void testDuration() { + TimeValue time = TimeValue.timeValueHours(3); + assertEquals(Duration.ofHours(3), DateUtils.toDuration(time)); + } +} diff --git a/src/test/java/org/opensearch/ad/util/ExceptionUtilsTests.java b/src/test/java/org/opensearch/ad/util/ExceptionUtilsTests.java index 967ecaaf9..8d64ba08e 100644 --- a/src/test/java/org/opensearch/ad/util/ExceptionUtilsTests.java +++ b/src/test/java/org/opensearch/ad/util/ExceptionUtilsTests.java @@ -14,8 +14,8 @@ import org.opensearch.OpenSearchException; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.replication.ReplicationResponse; -import org.opensearch.index.shard.ShardId; -import org.opensearch.rest.RestStatus; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.timeseries.common.exception.TimeSeriesException; diff --git a/src/test/java/org/opensearch/ad/util/ExceptionUtilsTests.java-e b/src/test/java/org/opensearch/ad/util/ExceptionUtilsTests.java-e new file mode 100644 index 000000000..8d64ba08e --- /dev/null +++ b/src/test/java/org/opensearch/ad/util/ExceptionUtilsTests.java-e @@ -0,0 +1,66 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.replication.ReplicationResponse; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.common.exception.TimeSeriesException; + +public class ExceptionUtilsTests extends OpenSearchTestCase { + + public void testGetShardsFailure() { + ShardId shardId = new ShardId(randomAlphaOfLength(5), randomAlphaOfLength(5), 1); + ReplicationResponse.ShardInfo.Failure failure = new ReplicationResponse.ShardInfo.Failure( + shardId, + randomAlphaOfLength(5), + new RuntimeException("test"), + RestStatus.BAD_REQUEST, + false + ); + ReplicationResponse.ShardInfo shardInfo = new ReplicationResponse.ShardInfo(2, 1, failure); + IndexResponse indexResponse = new IndexResponse(shardId, "id", randomLong(), randomLong(), randomLong(), randomBoolean()); + indexResponse.setShardInfo(shardInfo); + String shardsFailure = ExceptionUtil.getShardsFailure(indexResponse); + assertEquals("RuntimeException[test]", shardsFailure); + } + + public void testGetShardsFailureWithoutError() { + ShardId shardId = new ShardId(randomAlphaOfLength(5), randomAlphaOfLength(5), 1); + IndexResponse indexResponse = new IndexResponse(shardId, "id", randomLong(), randomLong(), randomLong(), randomBoolean()); + assertNull(ExceptionUtil.getShardsFailure(indexResponse)); + + ReplicationResponse.ShardInfo shardInfo = new ReplicationResponse.ShardInfo(2, 1, ReplicationResponse.EMPTY); + indexResponse.setShardInfo(shardInfo); + assertNull(ExceptionUtil.getShardsFailure(indexResponse)); + } + + public void testCountInStats() { + 
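+ // Editorial note: ExceptionUtil.countInStats decides whether an exception is counted toward failure stats;
+ // as the assertions below show, a TimeSeriesException can opt out via countedInStats(false), while other
+ // exceptions (e.g. RuntimeException) are counted by default.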
assertTrue(ExceptionUtil.countInStats(new TimeSeriesException("test"))); + assertFalse(ExceptionUtil.countInStats(new TimeSeriesException("test").countedInStats(false))); + assertTrue(ExceptionUtil.countInStats(new RuntimeException("test"))); + } + + public void testGetErrorMessage() { + assertEquals("test", ExceptionUtil.getErrorMessage(new TimeSeriesException("test"))); + assertEquals("test", ExceptionUtil.getErrorMessage(new IllegalArgumentException("test"))); + assertEquals("OpenSearchException[test]", ExceptionUtil.getErrorMessage(new OpenSearchException("test"))); + assertTrue( + ExceptionUtil + .getErrorMessage(new RuntimeException("test")) + .contains("at org.opensearch.ad.util.ExceptionUtilsTests.testGetErrorMessage") + ); + } +} diff --git a/src/test/java/org/opensearch/ad/util/IndexUtilsTests.java-e b/src/test/java/org/opensearch/ad/util/IndexUtilsTests.java-e new file mode 100644 index 000000000..bea6abf95 --- /dev/null +++ b/src/test/java/org/opensearch/ad/util/IndexUtilsTests.java-e @@ -0,0 +1,94 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import static org.mockito.Mockito.mock; + +import java.time.Clock; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.TestHelpers; + +public class IndexUtilsTests extends OpenSearchIntegTestCase { + + private ClientUtil clientUtil; + + private IndexNameExpressionResolver indexNameResolver; + + @Before + public void setup() { + Client client = client(); + Clock clock = mock(Clock.class); + Throttler throttler = new Throttler(clock); + ThreadPool context = TestHelpers.createThreadPool(); + clientUtil = new ClientUtil(Settings.EMPTY, client, throttler, context); + indexNameResolver = mock(IndexNameExpressionResolver.class); + } + + @Test + public void testGetIndexHealth_NoIndex() { + IndexUtils indexUtils = new IndexUtils(client(), clientUtil, clusterService(), indexNameResolver); + String output = indexUtils.getIndexHealthStatus("test"); + assertEquals(IndexUtils.NONEXISTENT_INDEX_STATUS, output); + } + + @Test + public void testGetIndexHealth_Index() { + String indexName = "test-2"; + createIndex(indexName); + flush(); + IndexUtils indexUtils = new IndexUtils(client(), clientUtil, clusterService(), indexNameResolver); + String status = indexUtils.getIndexHealthStatus(indexName); + assertTrue(status.equals("green") || status.equals("yellow")); + } + + @Test + public void testGetIndexHealth_Alias() { + String indexName = "test-2"; + String aliasName = "alias"; + createIndex(indexName); + flush(); + AcknowledgedResponse response = client().admin().indices().prepareAliases().addAlias(indexName, aliasName).execute().actionGet(); + assertTrue(response.isAcknowledged()); + IndexUtils indexUtils = new IndexUtils(client(), clientUtil, clusterService(), indexNameResolver); + String status = indexUtils.getIndexHealthStatus(aliasName); + assertTrue(status.equals("green") || status.equals("yellow")); + } + + 
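+ // Rough usage sketch (editorial note, not part of the test; names mirror the setup() fields and the
+ // constructor/method signatures exercised above):
+ // IndexUtils utils = new IndexUtils(client(), clientUtil, clusterService(), indexNameResolver);
+ // String health = utils.getIndexHealthStatus("my-index"); // "green"/"yellow" for a healthy index,
+ //                                                          // IndexUtils.NONEXISTENT_INDEX_STATUS if missing
+ // Long docs = utils.getNumberOfDocumentsInIndex("my-index");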
@Test + public void testGetNumberOfDocumentsInIndex_NonExistentIndex() { + IndexUtils indexUtils = new IndexUtils(client(), clientUtil, clusterService(), indexNameResolver); + assertEquals((Long) 0L, indexUtils.getNumberOfDocumentsInIndex("index")); + } + + @Test + public void testGetNumberOfDocumentsInIndex_RegularIndex() { + String indexName = "test-2"; + createIndex(indexName); + flush(); + + long count = 2100; + for (int i = 0; i < count; i++) { + index(indexName, "_doc", String.valueOf(i), "{}"); + } + flushAndRefresh(indexName); + IndexUtils indexUtils = new IndexUtils(client(), clientUtil, clusterService(), indexNameResolver); + assertEquals((Long) count, indexUtils.getNumberOfDocumentsInIndex(indexName)); + } +} diff --git a/src/test/java/org/opensearch/ad/util/MultiResponsesDelegateActionListenerTests.java-e b/src/test/java/org/opensearch/ad/util/MultiResponsesDelegateActionListenerTests.java-e new file mode 100644 index 000000000..b905ce623 --- /dev/null +++ b/src/test/java/org/opensearch/ad/util/MultiResponsesDelegateActionListenerTests.java-e @@ -0,0 +1,78 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.opensearch.timeseries.TestHelpers.randomHCADAnomalyDetectResult; + +import java.util.ArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.opensearch.action.ActionListener; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.model.DetectorProfile; +import org.opensearch.ad.model.EntityAnomalyResult; +import org.opensearch.test.OpenSearchTestCase; + +public class MultiResponsesDelegateActionListenerTests extends OpenSearchTestCase { + + public void testEmptyResponse() throws InterruptedException { + final CountDownLatch inProgressLatch = new CountDownLatch(1); + ActionListener actualListener = ActionListener.wrap(response -> { + assertTrue("Should not reach here", false); + inProgressLatch.countDown(); + }, exception -> { + String exceptionMsg = exception.getMessage(); + assertTrue(exceptionMsg, exceptionMsg.contains(MultiResponsesDelegateActionListener.NO_RESPONSE)); + inProgressLatch.countDown(); + }); + + MultiResponsesDelegateActionListener multiListener = new MultiResponsesDelegateActionListener( + actualListener, + 2, + "blah", + false + ); + multiListener.onResponse(null); + multiListener.onResponse(null); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + @SuppressWarnings("unchecked") + public void testForceResponse() { + AnomalyResult anomalyResult1 = randomHCADAnomalyDetectResult(0.25, 0.25, "error"); + AnomalyResult anomalyResult2 = randomHCADAnomalyDetectResult(0.5, 0.5, "error"); + + EntityAnomalyResult entityAnomalyResult1 = new EntityAnomalyResult(new ArrayList() { + { + add(anomalyResult1); + } + }); + EntityAnomalyResult entityAnomalyResult2 = new EntityAnomalyResult(new ArrayList() { + { + add(anomalyResult2); + } + }); + + ActionListener actualListener = mock(ActionListener.class); + MultiResponsesDelegateActionListener multiListener = + new MultiResponsesDelegateActionListener(actualListener, 3, "blah", true); + multiListener.onResponse(entityAnomalyResult1); + 
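+ // The second partial result and the failure below complete the expected three callbacks; because the
+ // delegate listener above was constructed with its boolean (force) flag set to true, it should still
+ // deliver the merged entity result to actualListener, as verified at the end of this test.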
multiListener.onResponse(entityAnomalyResult2); + multiListener.onFailure(new RuntimeException()); + entityAnomalyResult1.merge(entityAnomalyResult2); + + verify(actualListener).onResponse(entityAnomalyResult1); + } +} diff --git a/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java b/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java index 188d67e59..c2dd673b4 100644 --- a/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java @@ -20,9 +20,9 @@ import java.util.List; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.common.ParsingException; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.ParsingException; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.aggregations.AggregationBuilder; diff --git a/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java-e b/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java-e new file mode 100644 index 000000000..e7cd7f0e3 --- /dev/null +++ b/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java-e @@ -0,0 +1,309 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.util; + +import static org.opensearch.timeseries.util.ParseUtils.addUserBackendRolesFilter; +import static org.opensearch.timeseries.util.ParseUtils.isAdmin; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.List; + +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.core.common.ParsingException; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.collect.ImmutableList; + +public class ParseUtilsTests extends OpenSearchTestCase { + + public void testToInstant() throws IOException { + long epochMilli = Instant.now().toEpochMilli(); + XContentBuilder builder = XContentFactory.jsonBuilder().value(epochMilli); + XContentParser parser = this.createParser(builder); + parser.nextToken(); + Instant instant = ParseUtils.toInstant(parser); + assertEquals(epochMilli, instant.toEpochMilli()); + } + + public void testToInstantWithNullToken() throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder().value((Long) null); + XContentParser parser = this.createParser(builder); + parser.nextToken(); + XContentParser.Token token = parser.currentToken(); + assertEquals(token, XContentParser.Token.VALUE_NULL); + Instant instant = ParseUtils.toInstant(parser); + assertNull(instant); + } + + public void testToInstantWithNullValue() 
throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder().value(randomLong()); + XContentParser parser = this.createParser(builder); + parser.nextToken(); + parser.nextToken(); + XContentParser.Token token = parser.currentToken(); + assertNull(token); + Instant instant = ParseUtils.toInstant(parser); + assertNull(instant); + } + + public void testToInstantWithNotValue() throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().nullField("test").endObject(); + XContentParser parser = this.createParser(builder); + parser.nextToken(); + Instant instant = ParseUtils.toInstant(parser); + assertNull(instant); + } + + public void testToAggregationBuilder() throws IOException { + XContentParser parser = TestHelpers.parser("{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}"); + AggregationBuilder aggregationBuilder = ParseUtils.toAggregationBuilder(parser); + assertNotNull(aggregationBuilder); + assertEquals("aa", aggregationBuilder.getName()); + } + + public void testParseAggregatorsWithAggregationQueryString() throws IOException { + AggregatorFactories.Builder agg = ParseUtils + .parseAggregators("{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}", TestHelpers.xContentRegistry(), "test"); + assertEquals("test", agg.getAggregatorFactories().iterator().next().getName()); + } + + public void testParseAggregatorsWithInvalidAggregationName() throws IOException { + XContentParser parser = ParseUtils.parser("{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}", TestHelpers.xContentRegistry()); + Exception ex = expectThrows(ParsingException.class, () -> ParseUtils.parseAggregators(parser, 0, "#@?><:{")); + assertTrue(ex.getMessage().contains("Aggregation names must be alpha-numeric and can only contain '_' and '-'")); + } + + public void testParseAggregatorsWithTwoAggregationTypes() throws IOException { + XContentParser parser = ParseUtils + .parser("{\"test\":{\"avg\":{\"field\":\"value\"},\"sum\":{\"field\":\"value\"}}}", TestHelpers.xContentRegistry()); + Exception ex = expectThrows(ParsingException.class, () -> ParseUtils.parseAggregators(parser, 0, "test")); + assertTrue(ex.getMessage().contains("Found two aggregation type definitions in")); + } + + public void testParseAggregatorsWithNullAggregationDefinition() throws IOException { + String aggName = "test"; + XContentParser parser = ParseUtils.parser("{\"test\":{}}", TestHelpers.xContentRegistry()); + Exception ex = expectThrows(ParsingException.class, () -> ParseUtils.parseAggregators(parser, 0, aggName)); + assertTrue(ex.getMessage().contains("Missing definition for aggregation [" + aggName + "]")); + } + + public void testParseAggregatorsWithAggregationQueryStringAndNullAggName() throws IOException { + AggregatorFactories.Builder agg = ParseUtils + .parseAggregators("{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}", TestHelpers.xContentRegistry(), null); + assertEquals("aa", agg.getAggregatorFactories().iterator().next().getName()); + } + + public void testGenerateInternalFeatureQuery() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(null, Instant.now()); + long startTime = randomLong(); + long endTime = randomLong(); + SearchSourceBuilder builder = ParseUtils.generateInternalFeatureQuery(detector, startTime, endTime, TestHelpers.xContentRegistry()); + for (Feature feature : detector.getFeatureAttributes()) { + assertTrue(builder.toString().contains(feature.getId())); + } + } + + public void testAddUserRoleFilterWithNullUser() { + SearchSourceBuilder searchSourceBuilder 
= new SearchSourceBuilder(); + addUserBackendRolesFilter(null, searchSourceBuilder); + assertEquals("{}", searchSourceBuilder.toString()); + } + + public void testAddUserRoleFilterWithNullUserBackendRole() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + addUserBackendRolesFilter( + new User(randomAlphaOfLength(5), null, ImmutableList.of(randomAlphaOfLength(5)), ImmutableList.of(randomAlphaOfLength(5))), + searchSourceBuilder + ); + assertEquals( + "{\"query\":{\"bool\":{\"must\":[{\"nested\":{\"query\":{\"terms\":{\"user.backend_roles.keyword\":[]," + + "\"boost\":1.0}},\"path\":\"user\",\"ignore_unmapped\":false,\"score_mode\":\"none\",\"boost\":1.0}}]," + + "\"adjust_pure_negative\":true,\"boost\":1.0}}}", + searchSourceBuilder.toString() + ); + } + + public void testAddUserRoleFilterWithEmptyUserBackendRole() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + addUserBackendRolesFilter( + new User( + randomAlphaOfLength(5), + ImmutableList.of(), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(randomAlphaOfLength(5)) + ), + searchSourceBuilder + ); + assertEquals( + "{\"query\":{\"bool\":{\"must\":[{\"nested\":{\"query\":{\"terms\":{\"user.backend_roles.keyword\":[]," + + "\"boost\":1.0}},\"path\":\"user\",\"ignore_unmapped\":false,\"score_mode\":\"none\",\"boost\":1.0}}]," + + "\"adjust_pure_negative\":true,\"boost\":1.0}}}", + searchSourceBuilder.toString() + ); + } + + public void testAddUserRoleFilterWithNormalUserBackendRole() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + String backendRole1 = randomAlphaOfLength(5); + String backendRole2 = randomAlphaOfLength(5); + addUserBackendRolesFilter( + new User( + randomAlphaOfLength(5), + ImmutableList.of(backendRole1, backendRole2), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(randomAlphaOfLength(5)) + ), + searchSourceBuilder + ); + assertEquals( + "{\"query\":{\"bool\":{\"must\":[{\"nested\":{\"query\":{\"terms\":{\"user.backend_roles.keyword\":" + + "[\"" + + backendRole1 + + "\",\"" + + backendRole2 + + "\"]," + + "\"boost\":1.0}},\"path\":\"user\",\"ignore_unmapped\":false,\"score_mode\":\"none\",\"boost\":1.0}}]," + + "\"adjust_pure_negative\":true,\"boost\":1.0}}}", + searchSourceBuilder.toString() + ); + } + + public void testBatchFeatureQuery() throws IOException { + String index = randomAlphaOfLength(5); + Instant now = Instant.now().truncatedTo(ChronoUnit.SECONDS); + Feature feature1 = TestHelpers.randomFeature(true); + Feature feature2 = TestHelpers.randomFeature(false); + AnomalyDetector detector = TestHelpers + .randomAnomalyDetector(ImmutableList.of(index), ImmutableList.of(feature1, feature2), null, now, 1, false, null); + + long startTime = now.minus(10, ChronoUnit.DAYS).toEpochMilli(); + long endTime = now.plus(10, ChronoUnit.DAYS).toEpochMilli(); + SearchSourceBuilder searchSourceBuilder = ParseUtils + .batchFeatureQuery(detector, null, startTime, endTime, TestHelpers.xContentRegistry()); + assertEquals( + "{\"size\":0,\"query\":{\"bool\":{\"must\":[{\"range\":{\"" + + detector.getTimeField() + + "\":{\"from\":" + + startTime + + ",\"to\":" + + endTime + + ",\"include_lower\":true,\"include_upper\":false,\"format\":\"epoch_millis\",\"boost\"" + + ":1.0}}},{\"bool\":{\"must\":[{\"term\":{\"user\":{\"value\":\"kimchy\",\"boost\":1.0}}}],\"filter\":" + + "[{\"term\":{\"tag\":{\"value\":\"tech\",\"boost\":1.0}}}],\"must_not\":[{\"range\":{\"age\":{\"from\":10," + + 
"\"to\":20,\"include_lower\":true,\"include_upper\":true,\"boost\":1.0}}}],\"should\":[{\"term\":{\"tag\":" + + "{\"value\":\"wow\",\"boost\":1.0}}},{\"term\":{\"tag\":{\"value\":\"elasticsearch\",\"boost\":1.0}}}]," + + "\"adjust_pure_negative\":true,\"minimum_should_match\":\"1\",\"boost\":1.0}}],\"adjust_pure_negative" + + "\":true,\"boost\":1.0}},\"aggregations\":{\"feature_aggs\":{\"composite\":{\"size\":10000,\"sources\":" + + "[{\"date_histogram\":{\"date_histogram\":{\"field\":\"" + + detector.getTimeField() + + "\",\"missing_bucket\":false,\"order\":\"asc\"," + + "\"fixed_interval\":\"60s\"}}}]},\"aggregations\":{\"" + + feature1.getId() + + "\":{\"value_count\":{\"field\":\"ok\"}}}}}}", + searchSourceBuilder.toString() + ); + } + + public void testBatchFeatureQueryWithoutEnabledFeature() throws IOException { + String index = randomAlphaOfLength(5); + Instant now = Instant.now().truncatedTo(ChronoUnit.SECONDS); + AnomalyDetector detector = TestHelpers + .randomAnomalyDetector(ImmutableList.of(index), ImmutableList.of(TestHelpers.randomFeature(false)), null, now, 1, false, null); + + long startTime = now.minus(10, ChronoUnit.DAYS).toEpochMilli(); + long endTime = now.plus(10, ChronoUnit.DAYS).toEpochMilli(); + + TimeSeriesException exception = expectThrows( + TimeSeriesException.class, + () -> ParseUtils.batchFeatureQuery(detector, null, startTime, endTime, TestHelpers.xContentRegistry()) + ); + assertEquals("No enabled feature configured", exception.getMessage()); + } + + public void testBatchFeatureQueryWithoutFeature() throws IOException { + String index = randomAlphaOfLength(5); + Instant now = Instant.now().truncatedTo(ChronoUnit.SECONDS); + AnomalyDetector detector = TestHelpers + .randomAnomalyDetector(ImmutableList.of(index), ImmutableList.of(), null, now, 1, false, null); + + long startTime = now.minus(10, ChronoUnit.DAYS).toEpochMilli(); + long endTime = now.plus(10, ChronoUnit.DAYS).toEpochMilli(); + TimeSeriesException exception = expectThrows( + TimeSeriesException.class, + () -> ParseUtils.batchFeatureQuery(detector, null, startTime, endTime, TestHelpers.xContentRegistry()) + ); + assertEquals("No enabled feature configured", exception.getMessage()); + } + + public void testListEqualsWithoutConsideringOrder() { + assertTrue(ParseUtils.listEqualsWithoutConsideringOrder(null, null)); + assertTrue(ParseUtils.listEqualsWithoutConsideringOrder(null, ImmutableList.of())); + assertTrue(ParseUtils.listEqualsWithoutConsideringOrder(ImmutableList.of(), null)); + assertTrue(ParseUtils.listEqualsWithoutConsideringOrder(ImmutableList.of(), ImmutableList.of())); + + assertTrue(ParseUtils.listEqualsWithoutConsideringOrder(ImmutableList.of("a"), ImmutableList.of("a"))); + assertTrue(ParseUtils.listEqualsWithoutConsideringOrder(ImmutableList.of("a", "b"), ImmutableList.of("a", "b"))); + assertTrue(ParseUtils.listEqualsWithoutConsideringOrder(ImmutableList.of("b", "a"), ImmutableList.of("a", "b"))); + assertFalse(ParseUtils.listEqualsWithoutConsideringOrder(ImmutableList.of("a"), ImmutableList.of("a", "b"))); + assertFalse( + ParseUtils.listEqualsWithoutConsideringOrder(ImmutableList.of(randomAlphaOfLength(5)), ImmutableList.of(randomAlphaOfLength(5))) + ); + } + + public void testGetFeatureFieldNames() throws IOException { + Feature feature1 = TestHelpers.randomFeature("feature-name1", "field-name1", "sum", true); + Feature feature2 = TestHelpers.randomFeature("feature-name2", "field-name2", "sum", true); + Instant now = Instant.now().truncatedTo(ChronoUnit.SECONDS); + AnomalyDetector 
detector = TestHelpers.randomAnomalyDetector(ImmutableList.of(feature1, feature2), null, now); + List fieldNames = ParseUtils.getFeatureFieldNames(detector, TestHelpers.xContentRegistry()); + assertTrue(fieldNames.contains("field-name1")); + assertTrue(fieldNames.contains("field-name2")); + } + + public void testIsAdmin() { + User user1 = new User( + randomAlphaOfLength(5), + ImmutableList.of(), + ImmutableList.of("all_access"), + ImmutableList.of(randomAlphaOfLength(5)) + ); + assertTrue(isAdmin(user1)); + } + + public void testIsAdminBackendRoleIsAllAccess() { + String backendRole1 = "all_access"; + User user1 = new User( + randomAlphaOfLength(5), + ImmutableList.of(backendRole1), + ImmutableList.of(randomAlphaOfLength(5)), + ImmutableList.of(randomAlphaOfLength(5)) + ); + assertFalse(isAdmin(user1)); + } + + public void testIsAdminNull() { + assertFalse(isAdmin(null)); + } +} diff --git a/src/test/java/org/opensearch/ad/util/RestHandlerUtilsTests.java b/src/test/java/org/opensearch/ad/util/RestHandlerUtilsTests.java index c9a18468d..ecd60e5d4 100644 --- a/src/test/java/org/opensearch/ad/util/RestHandlerUtilsTests.java +++ b/src/test/java/org/opensearch/ad/util/RestHandlerUtilsTests.java @@ -18,7 +18,7 @@ import java.io.IOException; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.common.bytes.BytesReference; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; diff --git a/src/test/java/org/opensearch/ad/util/RestHandlerUtilsTests.java-e b/src/test/java/org/opensearch/ad/util/RestHandlerUtilsTests.java-e new file mode 100644 index 000000000..ecd60e5d4 --- /dev/null +++ b/src/test/java/org/opensearch/ad/util/RestHandlerUtilsTests.java-e @@ -0,0 +1,114 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.util; + +import static org.opensearch.timeseries.TestHelpers.builder; +import static org.opensearch.timeseries.TestHelpers.randomFeature; +import static org.opensearch.timeseries.util.RestHandlerUtils.OPENSEARCH_DASHBOARDS_USER_AGENT; + +import java.io.IOException; + +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestChannel; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.util.RestHandlerUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class RestHandlerUtilsTests extends OpenSearchTestCase { + + public void testGetSourceContextFromOpenSearchDashboardEmptyExcludes() { + FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY); + builder.withHeaders(ImmutableMap.of("User-Agent", ImmutableList.of(OPENSEARCH_DASHBOARDS_USER_AGENT, randomAlphaOfLength(10)))); + SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder(); + testSearchSourceBuilder.fetchSource(new String[] { "a" }, new String[0]); + FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(builder.build(), testSearchSourceBuilder); + assertArrayEquals(new String[] { "a" }, sourceContext.includes()); + assertEquals(0, sourceContext.excludes().length); + assertEquals(1, sourceContext.includes().length); + } + + public void testGetSourceContextFromClientWithExcludes() { + FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY); + SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder(); + testSearchSourceBuilder.fetchSource(new String[] { "a" }, new String[] { "b" }); + FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(builder.build(), testSearchSourceBuilder); + assertEquals(sourceContext.excludes().length, 2); + } + + public void testGetSourceContextFromClientWithoutSource() { + FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY); + SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder(); + FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(builder.build(), testSearchSourceBuilder); + assertEquals(sourceContext.excludes().length, 1); + assertEquals(sourceContext.includes().length, 0); + } + + public void testGetSourceContextOpenSearchDashboardWithoutSources() { + FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY); + builder.withHeaders(ImmutableMap.of("User-Agent", ImmutableList.of(OPENSEARCH_DASHBOARDS_USER_AGENT, randomAlphaOfLength(10)))); + SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder(); + FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(builder.build(), testSearchSourceBuilder); + assertNull(sourceContext); + } + + public void testCreateXContentParser() throws IOException { + RestRequest request = new FakeRestRequest(); + RestChannel channel = new 
FakeRestChannel(request, false, 1); + XContentBuilder builder = builder().startObject().field("test", "value").endObject(); + BytesReference bytesReference = BytesReference.bytes(builder); + XContentParser parser = RestHandlerUtils.createXContentParser(channel, bytesReference); + parser.close(); + } + + public void testisExceptionCausedByInvalidQueryNotSearchPhaseException() { + assertFalse(RestHandlerUtils.isExceptionCausedByInvalidQuery(new IllegalArgumentException())); + } + + public void testValidateAnomalyDetectorWithTooManyFeatures() throws IOException { + AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableList.of(randomFeature(), randomFeature())); + String error = RestHandlerUtils.checkFeaturesSyntax(detector, 1); + assertEquals("Can't create more than 1 features", error); + } + + public void testValidateAnomalyDetectorWithDuplicateFeatureNames() throws IOException { + String featureName = randomAlphaOfLength(5); + AnomalyDetector detector = TestHelpers + .randomAnomalyDetector( + ImmutableList.of(randomFeature(featureName, randomAlphaOfLength(5)), randomFeature(featureName, randomAlphaOfLength(5))) + ); + String error = RestHandlerUtils.checkFeaturesSyntax(detector, 2); + assertEquals("There are duplicate feature names: " + featureName, error); + } + + public void testValidateAnomalyDetectorWithDuplicateAggregationNames() throws IOException { + String aggregationName = randomAlphaOfLength(5); + AnomalyDetector detector = TestHelpers + .randomAnomalyDetector( + ImmutableList + .of(randomFeature(randomAlphaOfLength(5), aggregationName), randomFeature(randomAlphaOfLength(5), aggregationName)) + ); + String error = RestHandlerUtils.checkFeaturesSyntax(detector, 2); + assertEquals("Config has duplicate feature aggregation query names: " + aggregationName, error); + } +} diff --git a/src/test/java/org/opensearch/ad/util/ThrottlerTests.java-e b/src/test/java/org/opensearch/ad/util/ThrottlerTests.java-e new file mode 100644 index 000000000..61bb19ec8 --- /dev/null +++ b/src/test/java/org/opensearch/ad/util/ThrottlerTests.java-e @@ -0,0 +1,67 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.util; + +import static org.mockito.Mockito.mock; +import static org.powermock.api.mockito.PowerMockito.when; + +import java.time.Clock; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.test.OpenSearchTestCase; + +public class ThrottlerTests extends OpenSearchTestCase { + private Throttler throttler; + + @Before + public void setup() { + Clock clock = mock(Clock.class); + this.throttler = new Throttler(clock); + } + + @Test + public void testGetFilteredQuery() { + AnomalyDetector detector = mock(AnomalyDetector.class); + when(detector.getId()).thenReturn("test detector Id"); + SearchRequest dummySearchRequest = new SearchRequest(); + throttler.insertFilteredQuery(detector.getId(), dummySearchRequest); + // case 1: key exists + assertTrue(throttler.getFilteredQuery(detector.getId()).isPresent()); + // case 2: key doesn't exist + assertFalse(throttler.getFilteredQuery("different test detector Id").isPresent()); + } + + @Test + public void testInsertFilteredQuery() { + AnomalyDetector detector = mock(AnomalyDetector.class); + when(detector.getId()).thenReturn("test detector Id"); + SearchRequest dummySearchRequest = new SearchRequest(); + // first time: key doesn't exist + assertTrue(throttler.insertFilteredQuery(detector.getId(), dummySearchRequest)); + // second time: key exists + assertFalse(throttler.insertFilteredQuery(detector.getId(), dummySearchRequest)); + } + + @Test + public void testClearFilteredQuery() { + AnomalyDetector detector = mock(AnomalyDetector.class); + when(detector.getId()).thenReturn("test detector Id"); + SearchRequest dummySearchRequest = new SearchRequest(); + assertTrue(throttler.insertFilteredQuery(detector.getId(), dummySearchRequest)); + throttler.clearFilteredQuery(detector.getId()); + assertTrue(throttler.insertFilteredQuery(detector.getId(), dummySearchRequest)); + } + +} diff --git a/src/test/java/org/opensearch/ad/util/ThrowingSupplierWrapperTests.java-e b/src/test/java/org/opensearch/ad/util/ThrowingSupplierWrapperTests.java-e new file mode 100644 index 000000000..f7db4a278 --- /dev/null +++ b/src/test/java/org/opensearch/ad/util/ThrowingSupplierWrapperTests.java-e @@ -0,0 +1,30 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.util; + +import java.io.IOException; + +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.function.ThrowingSupplierWrapper; + +public class ThrowingSupplierWrapperTests extends OpenSearchTestCase { + private static String foo() throws IOException { + throw new IOException("blah"); + } + + public void testExceptionThrown() { + expectThrows( + RuntimeException.class, + () -> ThrowingSupplierWrapper.throwingSupplierWrapper(ThrowingSupplierWrapperTests::foo).get() + ); + } +} diff --git a/src/test/java/org/opensearch/forecast/indices/ForecastIndexManagementTests.java b/src/test/java/org/opensearch/forecast/indices/ForecastIndexManagementTests.java index 366db213a..bd1577fa2 100644 --- a/src/test/java/org/opensearch/forecast/indices/ForecastIndexManagementTests.java +++ b/src/test/java/org/opensearch/forecast/indices/ForecastIndexManagementTests.java @@ -26,13 +26,13 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.admin.indices.alias.get.GetAliasesResponse; import org.opensearch.action.admin.indices.get.GetIndexResponse; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.index.IndexNotFoundException; import org.opensearch.plugins.Plugin; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.indices.IndexManagementIntegTestCase; import org.opensearch.timeseries.settings.TimeSeriesSettings; @@ -49,11 +49,12 @@ protected boolean ignoreExternalCluster() { return true; } - // help register setting using AnomalyDetectorPlugin.getSettings. Otherwise, AnomalyDetectionIndices's constructor would fail due to - // unregistered settings like AD_RESULT_HISTORY_MAX_DOCS. + // help register setting using TimeSeriesAnalyticsPlugin.getSettings. + // Otherwise, ForecastIndexManagement's constructor would fail due to + // unregistered settings like FORECAST_RESULT_HISTORY_MAX_DOCS_PER_SHARD. 
@Override protected Collection> nodePlugins() { - return Collections.singletonList(AnomalyDetectorPlugin.class); + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); } @Before diff --git a/src/test/java/org/opensearch/forecast/indices/ForecastIndexManagementTests.java-e b/src/test/java/org/opensearch/forecast/indices/ForecastIndexManagementTests.java-e new file mode 100644 index 000000000..bd1577fa2 --- /dev/null +++ b/src/test/java/org/opensearch/forecast/indices/ForecastIndexManagementTests.java-e @@ -0,0 +1,351 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.indices; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; +import java.util.Locale; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.hamcrest.MatcherAssert; +import org.junit.Before; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.indices.alias.get.GetAliasesResponse; +import org.opensearch.action.admin.indices.get.GetIndexResponse; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.indices.IndexManagementIntegTestCase; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 0, numClientNodes = 0, supportsDedicatedMasters = false) +public class ForecastIndexManagementTests extends IndexManagementIntegTestCase { + private ForecastIndexManagement indices; + private Settings settings; + private DiscoveryNodeFilterer nodeFilter; + + @Override + protected boolean ignoreExternalCluster() { + return true; + } + + // help register setting using TimeSeriesAnalyticsPlugin.getSettings. + // Otherwise, ForecastIndexManagement's constructor would fail due to + // unregistered settings like FORECAST_RESULT_HISTORY_MAX_DOCS_PER_SHARD. 
+ @Override + protected Collection> nodePlugins() { + return Collections.singletonList(TimeSeriesAnalyticsPlugin.class); + } + + @Before + public void setup() throws IOException { + settings = Settings + .builder() + .put("plugins.forecast.forecast_result_history_rollover_period", TimeValue.timeValueHours(12)) + .put("plugins.forecast.forecast_result_history_retention_period", TimeValue.timeValueHours(24)) + .put("plugins.forecast.forecast_result_history_max_docs", 10000L) + .put("plugins.forecast.request_timeout", TimeValue.timeValueSeconds(10)) + .build(); + + internalCluster().ensureAtLeastNumDataNodes(1); + ensureStableCluster(1); + + nodeFilter = new DiscoveryNodeFilterer(clusterService()); + + indices = new ForecastIndexManagement( + client(), + clusterService(), + client().threadPool(), + settings, + nodeFilter, + TimeSeriesSettings.MAX_UPDATE_RETRY_TIMES + ); + } + + public void testForecastResultIndexNotExists() { + boolean exists = indices.doesDefaultResultIndexExist(); + assertFalse(exists); + } + + public void testForecastResultIndexExists() throws IOException { + indices.initDefaultResultIndexIfAbsent(TestHelpers.createActionListener(response -> { + boolean acknowledged = response.isAcknowledged(); + assertTrue(acknowledged); + }, failure -> { throw new RuntimeException("should not recreate index"); })); + TestHelpers.waitForIndexCreationToComplete(client(), ForecastIndex.RESULT.getIndexName()); + assertTrue(indices.doesDefaultResultIndexExist()); + } + + public void testForecastResultIndexExistsAndNotRecreate() throws IOException { + indices + .initDefaultResultIndexIfAbsent( + TestHelpers + .createActionListener( + response -> logger.info("Acknowledged: " + response.isAcknowledged()), + failure -> { throw new RuntimeException("should not recreate index"); } + ) + ); + TestHelpers.waitForIndexCreationToComplete(client(), ForecastIndex.RESULT.getIndexName()); + if (client().admin().indices().prepareExists(ForecastIndex.RESULT.getIndexName()).get().isExists()) { + indices + .initDefaultResultIndexIfAbsent( + TestHelpers + .createActionListener( + response -> { throw new RuntimeException("should not recreate index " + ForecastIndex.RESULT.getIndexName()); }, + failure -> { + throw new RuntimeException("should not recreate index " + ForecastIndex.RESULT.getIndexName(), failure); + } + ) + ); + } + } + + public void testCheckpointIndexNotExists() { + boolean exists = indices.doesCheckpointIndexExist(); + assertFalse(exists); + } + + public void testCheckpointIndexExists() throws IOException { + indices.initCheckpointIndex(TestHelpers.createActionListener(response -> { + boolean acknowledged = response.isAcknowledged(); + assertTrue(acknowledged); + }, failure -> { throw new RuntimeException("should not recreate index"); })); + TestHelpers.waitForIndexCreationToComplete(client(), ForecastIndex.STATE.getIndexName()); + assertTrue(indices.doesCheckpointIndexExist()); + } + + public void testStateIndexNotExists() { + boolean exists = indices.doesStateIndexExist(); + assertFalse(exists); + } + + public void testStateIndexExists() throws IOException { + indices.initStateIndex(TestHelpers.createActionListener(response -> { + boolean acknowledged = response.isAcknowledged(); + assertTrue(acknowledged); + }, failure -> { throw new RuntimeException("should not recreate index"); })); + TestHelpers.waitForIndexCreationToComplete(client(), ForecastIndex.STATE.getIndexName()); + assertTrue(indices.doesStateIndexExist()); + } + + public void testConfigIndexNotExists() { + boolean exists = 
indices.doesConfigIndexExist(); + assertFalse(exists); + } + + public void testConfigIndexExists() throws IOException { + indices.initConfigIndex(TestHelpers.createActionListener(response -> { + boolean acknowledged = response.isAcknowledged(); + assertTrue(acknowledged); + }, failure -> { throw new RuntimeException("should not recreate index"); })); + TestHelpers.waitForIndexCreationToComplete(client(), ForecastIndex.CONFIG.getIndexName()); + assertTrue(indices.doesConfigIndexExist()); + } + + public void testCustomResultIndexExists() throws IOException { + String indexName = "a"; + assertTrue(!(client().admin().indices().prepareExists(indexName).get().isExists())); + indices + .initCustomResultIndexDirectly( + indexName, + TestHelpers + .createActionListener( + response -> logger.info("Acknowledged: " + response.isAcknowledged()), + failure -> { throw new RuntimeException("should not recreate index"); } + ) + ); + TestHelpers.waitForIndexCreationToComplete(client(), indexName); + assertTrue((client().admin().indices().prepareExists(indexName).get().isExists())); + } + + public void testJobIndexNotExists() { + boolean exists = indices.doesJobIndexExist(); + assertFalse(exists); + } + + public void testJobIndexExists() throws IOException { + indices.initJobIndex(TestHelpers.createActionListener(response -> { + boolean acknowledged = response.isAcknowledged(); + assertTrue(acknowledged); + }, failure -> { throw new RuntimeException("should not recreate index"); })); + TestHelpers.waitForIndexCreationToComplete(client(), ForecastIndex.JOB.getIndexName()); + assertTrue(indices.doesJobIndexExist()); + } + + public void testValidateCustomIndexForBackendJobNoIndex() { + validateCustomIndexForBackendJobNoIndex(indices); + } + + public void testValidateCustomIndexForBackendJobInvalidMapping() { + validateCustomIndexForBackendJobInvalidMapping(indices); + } + + public void testValidateCustomIndexForBackendJob() throws IOException, InterruptedException { + validateCustomIndexForBackendJob(indices, ForecastIndexManagement.getResultMappings()); + } + + public void testRollOver() throws IOException, InterruptedException { + indices.initDefaultResultIndexIfAbsent(TestHelpers.createActionListener(response -> { + boolean acknowledged = response.isAcknowledged(); + assertTrue(acknowledged); + }, failure -> { throw new RuntimeException("should not recreate index"); })); + TestHelpers.waitForIndexCreationToComplete(client(), ForecastIndex.RESULT.getIndexName()); + client().index(indices.createDummyIndexRequest(ForecastIndex.RESULT.getIndexName())).actionGet(); + + GetAliasesResponse getAliasesResponse = admin().indices().prepareGetAliases(ForecastIndex.RESULT.getIndexName()).get(); + String oldIndex = getAliasesResponse.getAliases().keySet().iterator().next(); + + settings = Settings + .builder() + .put("plugins.forecast.forecast_result_history_rollover_period", TimeValue.timeValueHours(12)) + .put("plugins.forecast.forecast_result_history_retention_period", TimeValue.timeValueHours(0)) + .put("plugins.forecast.forecast_result_history_max_docs", 0L) + .put("plugins.forecast.forecast_result_history_max_docs_per_shard", 0L) + .put("plugins.forecast.request_timeout", TimeValue.timeValueSeconds(10)) + .build(); + + nodeFilter = new DiscoveryNodeFilterer(clusterService()); + + indices = new ForecastIndexManagement( + client(), + clusterService(), + client().threadPool(), + settings, + nodeFilter, + TimeSeriesSettings.MAX_UPDATE_RETRY_TIMES + ); + indices.rolloverAndDeleteHistoryIndex(); + + // replace the last 
two characters "-1" to "000002"? + // Example: + // Input: opensearch-forecast-results-history-2023.06.15-1 + // Output: opensearch-forecast-results-history-2023.06.15-000002 + String newIndex = oldIndex.replaceFirst("-1$", "-000002"); + TestHelpers.waitForIndexCreationToComplete(client(), newIndex); + + getAliasesResponse = admin().indices().prepareGetAliases(ForecastIndex.RESULT.getIndexName()).get(); + String currentPointedIndex = getAliasesResponse.getAliases().keySet().iterator().next(); + assertEquals(newIndex, currentPointedIndex); + + client().index(indices.createDummyIndexRequest(ForecastIndex.RESULT.getIndexName())).actionGet(); + // now we have two indices + indices.rolloverAndDeleteHistoryIndex(); + + String thirdIndexName = getIncrementedIndex(newIndex); + TestHelpers.waitForIndexCreationToComplete(client(), thirdIndexName); + getAliasesResponse = admin().indices().prepareGetAliases(ForecastIndex.RESULT.getIndexName()).get(); + currentPointedIndex = getAliasesResponse.getAliases().keySet().iterator().next(); + assertEquals(thirdIndexName, currentPointedIndex); + + // we have already deleted the oldest index since retention period is 0 hrs + int retry = 0; + while (retry < 10) { + try { + client().admin().indices().prepareGetIndex().addIndices(oldIndex).get(); + retry++; + // wait for index to be deleted + Thread.sleep(1000); + } catch (IndexNotFoundException e) { + MatcherAssert.assertThat(e.getMessage(), is(String.format(Locale.ROOT, "no such index [%s]", oldIndex))); + break; + } + } + + assertTrue(retry < 20); + + // 2nd oldest index should be fine as we keep at one old index + GetIndexResponse response = client().admin().indices().prepareGetIndex().addIndices(newIndex).get(); + String[] indicesInResponse = response.indices(); + MatcherAssert.assertThat(indicesInResponse, notNullValue()); + MatcherAssert.assertThat(indicesInResponse.length, equalTo(1)); + MatcherAssert.assertThat(indicesInResponse[0], equalTo(newIndex)); + + response = client().admin().indices().prepareGetIndex().addIndices(thirdIndexName).get(); + indicesInResponse = response.indices(); + MatcherAssert.assertThat(indicesInResponse, notNullValue()); + MatcherAssert.assertThat(indicesInResponse.length, equalTo(1)); + MatcherAssert.assertThat(indicesInResponse[0], equalTo(thirdIndexName)); + } + + /** + * Increment the last digit oif an index name. + * @param input. 
Example: opensearch-forecast-results-history-2023.06.15-000002 + * @return Example: opensearch-forecast-results-history-2023.06.15-000003 + */ + private String getIncrementedIndex(String input) { + int lastDash = input.lastIndexOf('-'); + + String prefix = input.substring(0, lastDash + 1); + String numberPart = input.substring(lastDash + 1); + + // Increment the number part + int incrementedNumber = Integer.parseInt(numberPart) + 1; + + // Use String.format to keep the leading zeros + String newNumberPart = String.format(Locale.ROOT, "%06d", incrementedNumber); + + return prefix + newNumberPart; + } + + public void testInitCustomResultIndexAndExecuteIndexNotExist() throws InterruptedException { + String resultIndex = "abc"; + ExecutorFunction function = mock(ExecutorFunction.class); + ActionListener listener = mock(ActionListener.class); + + CountDownLatch latch = new CountDownLatch(1); + doAnswer(invocation -> { + latch.countDown(); + return null; + }).when(function).execute(); + + indices.initCustomResultIndexAndExecute(resultIndex, function, listener); + latch.await(20, TimeUnit.SECONDS); + verify(listener, never()).onFailure(any(Exception.class)); + } + + public void testInitCustomResultIndexAndExecuteIndex() throws InterruptedException, IOException { + String indexName = "abc"; + ExecutorFunction function = mock(ExecutorFunction.class); + ActionListener listener = mock(ActionListener.class); + + indices + .initCustomResultIndexDirectly( + indexName, + TestHelpers + .createActionListener( + response -> logger.info("Acknowledged: " + response.isAcknowledged()), + failure -> { throw new RuntimeException("should not recreate index"); } + ) + ); + TestHelpers.waitForIndexCreationToComplete(client(), indexName); + CountDownLatch latch = new CountDownLatch(1); + doAnswer(invocation -> { + latch.countDown(); + return null; + }).when(function).execute(); + + indices.initCustomResultIndexAndExecute(indexName, function, listener); + latch.await(20, TimeUnit.SECONDS); + verify(listener, never()).onFailure(any(Exception.class)); + } +} diff --git a/src/test/java/org/opensearch/forecast/indices/ForecastIndexMappingTests.java-e b/src/test/java/org/opensearch/forecast/indices/ForecastIndexMappingTests.java-e new file mode 100644 index 000000000..a79eda373 --- /dev/null +++ b/src/test/java/org/opensearch/forecast/indices/ForecastIndexMappingTests.java-e @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.indices; + +import java.io.IOException; + +import org.opensearch.test.OpenSearchTestCase; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +public class ForecastIndexMappingTests extends OpenSearchTestCase { + + public void testGetForecastResultMappings() throws IOException { + String mapping = ForecastIndexManagement.getResultMappings(); + + // Use Jackson to convert the string into a JsonNode + ObjectMapper mapper = new ObjectMapper(); + JsonNode mappingJson = mapper.readTree(mapping); + + // Check the existence of some fields + assertTrue("forecaster_id field is missing", mappingJson.path("properties").has("forecaster_id")); + assertTrue("feature_data field is missing", mappingJson.path("properties").has("feature_data")); + assertTrue("data_start_time field is missing", mappingJson.path("properties").has("data_start_time")); + assertTrue("execution_start_time field is missing", mappingJson.path("properties").has("execution_start_time")); + assertTrue("user 
field is missing", mappingJson.path("properties").has("user")); + assertTrue("entity field is missing", mappingJson.path("properties").has("entity")); + assertTrue("schema_version field is missing", mappingJson.path("properties").has("schema_version")); + assertTrue("task_id field is missing", mappingJson.path("properties").has("task_id")); + assertTrue("model_id field is missing", mappingJson.path("properties").has("model_id")); + assertTrue("forecast_series field is missing", mappingJson.path("properties").has("forecast_value")); + } + + public void testGetCheckpointMappings() throws IOException { + String mapping = ForecastIndexManagement.getCheckpointMappings(); + + // Use Jackson to convert the string into a JsonNode + ObjectMapper mapper = new ObjectMapper(); + JsonNode mappingJson = mapper.readTree(mapping); + + // Check the existence of some fields + assertTrue("forecaster_id field is missing", mappingJson.path("properties").has("forecaster_id")); + assertTrue("timestamp field is missing", mappingJson.path("properties").has("timestamp")); + assertTrue("schema_version field is missing", mappingJson.path("properties").has("schema_version")); + assertTrue("entity field is missing", mappingJson.path("properties").has("entity")); + assertTrue("model field is missing", mappingJson.path("properties").has("model")); + assertTrue("samples field is missing", mappingJson.path("properties").has("samples")); + assertTrue("last_processed_sample field is missing", mappingJson.path("properties").has("last_processed_sample")); + } + + public void testGetStateMappings() throws IOException { + String mapping = ForecastIndexManagement.getStateMappings(); + + // Use Jackson to convert the string into a JsonNode + ObjectMapper mapper = new ObjectMapper(); + JsonNode mappingJson = mapper.readTree(mapping); + + // Check the existence of some fields + assertTrue("schema_version field is missing", mappingJson.path("properties").has("schema_version")); + assertTrue("last_update_time field is missing", mappingJson.path("properties").has("last_update_time")); + assertTrue("error field is missing", mappingJson.path("properties").has("error")); + assertTrue("started_by field is missing", mappingJson.path("properties").has("started_by")); + assertTrue("stopped_by field is missing", mappingJson.path("properties").has("stopped_by")); + assertTrue("forecaster_id field is missing", mappingJson.path("properties").has("forecaster_id")); + assertTrue("state field is missing", mappingJson.path("properties").has("state")); + assertTrue("task_progress field is missing", mappingJson.path("properties").has("task_progress")); + assertTrue("init_progress field is missing", mappingJson.path("properties").has("init_progress")); + assertTrue("current_piece field is missing", mappingJson.path("properties").has("current_piece")); + assertTrue("execution_start_time field is missing", mappingJson.path("properties").has("execution_start_time")); + assertTrue("execution_end_time field is missing", mappingJson.path("properties").has("execution_end_time")); + assertTrue("is_latest field is missing", mappingJson.path("properties").has("is_latest")); + assertTrue("task_type field is missing", mappingJson.path("properties").has("task_type")); + assertTrue("checkpoint_id field is missing", mappingJson.path("properties").has("checkpoint_id")); + assertTrue("coordinating_node field is missing", mappingJson.path("properties").has("coordinating_node")); + assertTrue("worker_node field is missing", 
mappingJson.path("properties").has("worker_node")); + assertTrue("user field is missing", mappingJson.path("properties").has("user")); + assertTrue("forecaster field is missing", mappingJson.path("properties").has("forecaster")); + assertTrue("date_range field is missing", mappingJson.path("properties").has("date_range")); + assertTrue("parent_task_id field is missing", mappingJson.path("properties").has("parent_task_id")); + assertTrue("entity field is missing", mappingJson.path("properties").has("entity")); + assertTrue("estimated_minutes_left field is missing", mappingJson.path("properties").has("estimated_minutes_left")); + } + +} diff --git a/src/test/java/org/opensearch/forecast/indices/ForecastResultIndexTests.java b/src/test/java/org/opensearch/forecast/indices/ForecastResultIndexTests.java index 7b537de44..0f9158ba4 100644 --- a/src/test/java/org/opensearch/forecast/indices/ForecastResultIndexTests.java +++ b/src/test/java/org/opensearch/forecast/indices/ForecastResultIndexTests.java @@ -39,9 +39,9 @@ import org.opensearch.common.UUIDs; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.core.index.Index; import org.opensearch.env.Environment; import org.opensearch.forecast.settings.ForecastSettings; -import org.opensearch.index.Index; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.common.exception.EndRunException; diff --git a/src/test/java/org/opensearch/forecast/indices/ForecastResultIndexTests.java-e b/src/test/java/org/opensearch/forecast/indices/ForecastResultIndexTests.java-e new file mode 100644 index 000000000..bdc15f141 --- /dev/null +++ b/src/test/java/org/opensearch/forecast/indices/ForecastResultIndexTests.java-e @@ -0,0 +1,229 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.indices; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; + +import org.mockito.ArgumentCaptor; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.cluster.state.ClusterStateRequest; +import org.opensearch.action.admin.cluster.state.ClusterStateResponse; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.UUIDs; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.env.Environment; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.core.index.Index; +import 
org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +public class ForecastResultIndexTests extends AbstractTimeSeriesTest { + private ForecastIndexManagement forecastIndices; + private IndicesAdminClient indicesClient; + private ClusterAdminClient clusterAdminClient; + private ClusterName clusterName; + private ClusterState clusterState; + private ClusterService clusterService; + private long defaultMaxDocs; + private int numberOfNodes; + private Client client; + + @Override + public void setUp() throws Exception { + super.setUp(); + client = mock(Client.class); + indicesClient = mock(IndicesAdminClient.class); + AdminClient adminClient = mock(AdminClient.class); + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings( + Settings.EMPTY, + Collections + .unmodifiableSet( + new HashSet<>( + Arrays + .asList( + ForecastSettings.FORECAST_RESULT_HISTORY_MAX_DOCS_PER_SHARD, + ForecastSettings.FORECAST_RESULT_HISTORY_ROLLOVER_PERIOD, + ForecastSettings.FORECAST_RESULT_HISTORY_RETENTION_PERIOD, + ForecastSettings.FORECAST_MAX_PRIMARY_SHARDS + ) + ) + ) + ); + + clusterName = new ClusterName("test"); + + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + + ThreadPool threadPool = mock(ThreadPool.class); + Settings settings = Settings.EMPTY; + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesClient); + + DiscoveryNodeFilterer nodeFilter = mock(DiscoveryNodeFilterer.class); + numberOfNodes = 2; + when(nodeFilter.getNumberOfEligibleDataNodes()).thenReturn(numberOfNodes); + + forecastIndices = new ForecastIndexManagement( + client, + clusterService, + threadPool, + settings, + nodeFilter, + TimeSeriesSettings.MAX_UPDATE_RETRY_TIMES + ); + + clusterAdminClient = mock(ClusterAdminClient.class); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + + doAnswer(invocation -> { + ClusterStateRequest clusterStateRequest = invocation.getArgument(0); + assertEquals(ForecastIndexManagement.ALL_FORECAST_RESULTS_INDEX_PATTERN, clusterStateRequest.indices()[0]); + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArgument(1); + listener.onResponse(new ClusterStateResponse(clusterName, clusterState, true)); + return null; + }).when(clusterAdminClient).state(any(), any()); + + defaultMaxDocs = ForecastSettings.FORECAST_RESULT_HISTORY_MAX_DOCS_PER_SHARD.getDefault(Settings.EMPTY); + + clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build(); + when(clusterService.state()).thenReturn(clusterState); + } + + public void testMappingSetToUpdated() throws IOException { + try { + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArgument(1); + listener.onResponse(new CreateIndexResponse(true, true, "blah")); + return null; + }).when(indicesClient).create(any(), any()); + + super.setUpLog4jForJUnit(IndexManagement.class); + + ActionListener listener = mock(ActionListener.class); + forecastIndices.initDefaultResultIndexDirectly(listener); + verify(listener, times(1)).onResponse(any(CreateIndexResponse.class)); + 
assertTrue(testAppender.containsMessage("mapping up-to-date")); + } finally { + super.tearDownLog4jForJUnit(); + } + + } + + public void testInitCustomResultIndexNoAck() { + ExecutorFunction function = mock(ExecutorFunction.class); + ActionListener listener = mock(ActionListener.class); + + doAnswer(invocation -> { + ActionListener createIndexListener = (ActionListener) invocation.getArgument(1); + createIndexListener.onResponse(new CreateIndexResponse(false, false, "blah")); + return null; + }).when(indicesClient).create(any(), any()); + + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + forecastIndices.initCustomResultIndexAndExecute("abc", function, listener); + verify(listener, times(1)).onFailure(response.capture()); + Exception value = response.getValue(); + assertTrue(value instanceof EndRunException); + assertTrue( + "actual: " + value.getMessage(), + value.getMessage().contains("Creating result index with mappings call not acknowledged") + ); + + } + + public void testInitCustomResultIndexAlreadyExist() throws IOException { + ExecutorFunction function = mock(ExecutorFunction.class); + ActionListener listener = mock(ActionListener.class); + + String indexName = "abc"; + + Settings settings = Settings + .builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(IndexMetadata.SETTING_INDEX_UUID, UUIDs.randomBase64UUID()) + .put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()) + .build(); + IndexMetadata indexMetaData = IndexMetadata + .builder(indexName) + .settings(settings) + .putMapping(ForecastIndexManagement.getResultMappings()) + .build(); + final Map indices = new HashMap<>(); + indices.put(indexName, indexMetaData); + + clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().indices(indices).build()).build(); + when(clusterService.state()).thenReturn(clusterState); + + doAnswer(invocation -> { + ActionListener createIndexListener = (ActionListener) invocation.getArgument(1); + createIndexListener.onFailure(new ResourceAlreadyExistsException(new Index(indexName, indexName))); + return null; + }).when(indicesClient).create(any(), any()); + + forecastIndices.initCustomResultIndexAndExecute(indexName, function, listener); + verify(listener, never()).onFailure(any()); + } + + public void testInitCustomResultIndexUnknownException() throws IOException { + ExecutorFunction function = mock(ExecutorFunction.class); + ActionListener listener = mock(ActionListener.class); + + String indexName = "abc"; + String exceptionMsg = "blah"; + + doAnswer(invocation -> { + ActionListener createIndexListener = (ActionListener) invocation.getArgument(1); + createIndexListener.onFailure(new IllegalArgumentException(exceptionMsg)); + return null; + }).when(indicesClient).create(any(), any()); + super.setUpLog4jForJUnit(IndexManagement.class); + try { + forecastIndices.initCustomResultIndexAndExecute(indexName, function, listener); + ArgumentCaptor response = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(response.capture()); + + Exception value = response.getValue(); + assertTrue(value instanceof IllegalArgumentException); + assertTrue("actual: " + value.getMessage(), value.getMessage().contains(exceptionMsg)); + } finally { + super.tearDownLog4jForJUnit(); + } + } +} diff --git a/src/test/java/org/opensearch/forecast/model/ForecastResultTests.java-e 
b/src/test/java/org/opensearch/forecast/model/ForecastResultTests.java-e new file mode 100644 index 000000000..8e7c4e17a --- /dev/null +++ b/src/test/java/org/opensearch/forecast/model/ForecastResultTests.java-e @@ -0,0 +1,103 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.junit.Before; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.FeatureData; + +public class ForecastResultTests extends OpenSearchTestCase { + List result; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + // Arrange + String forecasterId = "testId"; + long intervalMillis = 1000; + Double dataQuality = 0.9; + List featureData = new ArrayList<>(); + featureData.add(new FeatureData("f1", "f1", 1.0d)); + featureData.add(new FeatureData("f2", "f2", 2.0d)); + long currentTimeMillis = System.currentTimeMillis(); + Instant instantFromMillis = Instant.ofEpochMilli(currentTimeMillis); + Instant dataStartTime = instantFromMillis; + Instant dataEndTime = dataStartTime.plusSeconds(10); + Instant executionStartTime = instantFromMillis; + Instant executionEndTime = executionStartTime.plusSeconds(10); + String error = null; + Optional entity = Optional.empty(); + User user = new User("testUser", Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + Integer schemaVersion = 1; + String modelId = "testModelId"; + float[] forecastsValues = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + float[] forecastsUppers = new float[] { 1.5f, 2.5f, 3.5f, 4.5f }; + float[] forecastsLowers = new float[] { 0.5f, 1.5f, 2.5f, 3.5f }; + String taskId = "testTaskId"; + + // Act + result = ForecastResult + .fromRawRCFCasterResult( + forecasterId, + intervalMillis, + dataQuality, + featureData, + dataStartTime, + dataEndTime, + executionStartTime, + executionEndTime, + error, + entity, + user, + schemaVersion, + modelId, + forecastsValues, + forecastsUppers, + forecastsLowers, + taskId + ); + } + + public void testFromRawRCFCasterResult() { + // Assert + assertEquals(5, result.size()); + assertEquals("f1", result.get(1).getFeatureId()); + assertEquals(1.0f, result.get(1).getForecastValue(), 0.01); + assertEquals("f2", result.get(2).getFeatureId()); + assertEquals(2.0f, result.get(2).getForecastValue(), 0.01); + + assertTrue( + "actual: " + result.toString(), + result + .toString() + .contains( + "featureId=f2,dataQuality=0.9,forecastValue=2.0,lowerBound=1.5,upperBound=2.5,confidenceIntervalWidth=1.0,forecastDataStartTime=" + ) + ); + } + + public void testParseAnomalyDetector() throws IOException { + for (int i = 0; i < 5; i++) { + String forecastResultString = TestHelpers + .xContentBuilderToString(result.get(i).toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + ForecastResult parsedForecastResult = ForecastResult.parse(TestHelpers.parser(forecastResultString)); + assertEquals("Parsing forecast result doesn't work", result.get(i), parsedForecastResult); + assertTrue("Parsing forecast result doesn't work", result.get(i).hashCode() == parsedForecastResult.hashCode()); + } + } +} diff --git 
a/src/test/java/org/opensearch/forecast/model/ForecastSerializationTests.java b/src/test/java/org/opensearch/forecast/model/ForecastSerializationTests.java index e7adbfc63..e83fce6d9 100644 --- a/src/test/java/org/opensearch/forecast/model/ForecastSerializationTests.java +++ b/src/test/java/org/opensearch/forecast/model/ForecastSerializationTests.java @@ -8,19 +8,19 @@ import java.io.IOException; import java.util.Collection; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.plugins.Plugin; import org.opensearch.test.InternalSettingsPlugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; public class ForecastSerializationTests extends OpenSearchSingleNodeTestCase { @Override protected Collection> getPlugins() { - return pluginList(InternalSettingsPlugin.class, AnomalyDetectorPlugin.class); + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); } @Override diff --git a/src/test/java/org/opensearch/forecast/model/ForecastSerializationTests.java-e b/src/test/java/org/opensearch/forecast/model/ForecastSerializationTests.java-e new file mode 100644 index 000000000..e83fce6d9 --- /dev/null +++ b/src/test/java/org/opensearch/forecast/model/ForecastSerializationTests.java-e @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +import java.io.IOException; +import java.util.Collection; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.InternalSettingsPlugin; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +public class ForecastSerializationTests extends OpenSearchSingleNodeTestCase { + @Override + protected Collection> getPlugins() { + return pluginList(InternalSettingsPlugin.class, TimeSeriesAnalyticsPlugin.class); + } + + @Override + protected NamedWriteableRegistry writableRegistry() { + return getInstanceFromNode(NamedWriteableRegistry.class); + } + + public void testStreamConstructor() throws IOException { + Forecaster forecaster = TestHelpers.randomForecaster(); + + BytesStreamOutput output = new BytesStreamOutput(); + + forecaster.writeTo(output); + NamedWriteableAwareStreamInput streamInput = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + Forecaster parsedForecaster = new Forecaster(streamInput); + assertTrue(parsedForecaster.equals(forecaster)); + } + + public void testStreamConstructorNullUser() throws IOException { + Forecaster forecaster = TestHelpers.ForecasterBuilder.newInstance().setUser(null).build(); + + BytesStreamOutput output = new BytesStreamOutput(); + + forecaster.writeTo(output); + NamedWriteableAwareStreamInput streamInput = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); 
+ Forecaster parsedForecaster = new Forecaster(streamInput); + assertTrue(parsedForecaster.equals(forecaster)); + } + + public void testStreamConstructorNullUiMeta() throws IOException { + Forecaster forecaster = TestHelpers.ForecasterBuilder.newInstance().setUiMetadata(null).build(); + + BytesStreamOutput output = new BytesStreamOutput(); + + forecaster.writeTo(output); + NamedWriteableAwareStreamInput streamInput = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + Forecaster parsedForecaster = new Forecaster(streamInput); + assertTrue(parsedForecaster.equals(forecaster)); + } + + public void testStreamConstructorNullCustomResult() throws IOException { + Forecaster forecaster = TestHelpers.ForecasterBuilder.newInstance().setCustomResultIndex(null).build(); + + BytesStreamOutput output = new BytesStreamOutput(); + + forecaster.writeTo(output); + NamedWriteableAwareStreamInput streamInput = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + Forecaster parsedForecaster = new Forecaster(streamInput); + assertTrue(parsedForecaster.equals(forecaster)); + } + + public void testStreamConstructorNullImputationOption() throws IOException { + Forecaster forecaster = TestHelpers.ForecasterBuilder.newInstance().setNullImputationOption().build(); + + BytesStreamOutput output = new BytesStreamOutput(); + + forecaster.writeTo(output); + NamedWriteableAwareStreamInput streamInput = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); + Forecaster parsedForecaster = new Forecaster(streamInput); + assertTrue(parsedForecaster.equals(forecaster)); + } +} diff --git a/src/test/java/org/opensearch/forecast/model/ForecasterTests.java-e b/src/test/java/org/opensearch/forecast/model/ForecasterTests.java-e new file mode 100644 index 000000000..0b64912bf --- /dev/null +++ b/src/test/java/org/opensearch/forecast/model/ForecasterTests.java-e @@ -0,0 +1,396 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.model; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.Matchers.is; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.hamcrest.MatcherAssert; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.dataprocessor.ImputationOption; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.TimeConfiguration; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; + +public class ForecasterTests extends AbstractTimeSeriesTest { + TimeConfiguration forecastInterval = new IntervalTimeConfiguration(1, ChronoUnit.MINUTES); + 
TimeConfiguration windowDelay = new IntervalTimeConfiguration(1, ChronoUnit.MINUTES); + String forecasterId = "testId"; + Long version = 1L; + String name = "testName"; + String description = "testDescription"; + String timeField = "testTimeField"; + List indices = Collections.singletonList("testIndex"); + List features = Collections.emptyList(); // Assuming no features for simplicity + MatchAllQueryBuilder filterQuery = QueryBuilders.matchAllQuery(); + Integer shingleSize = 1; + Map uiMetadata = new HashMap<>(); + Integer schemaVersion = 1; + Instant lastUpdateTime = Instant.now(); + List categoryFields = Arrays.asList("field1", "field2"); + User user = new User("testUser", Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + String resultIndex = null; + Integer horizon = 1; + + public void testForecasterConstructor() { + ImputationOption imputationOption = TestHelpers.randomImputationOption(); + + Forecaster forecaster = new Forecaster( + forecasterId, + version, + name, + description, + timeField, + indices, + features, + filterQuery, + forecastInterval, + windowDelay, + shingleSize, + uiMetadata, + schemaVersion, + lastUpdateTime, + categoryFields, + user, + resultIndex, + horizon, + imputationOption + ); + + assertEquals(forecasterId, forecaster.getId()); + assertEquals(version, forecaster.getVersion()); + assertEquals(name, forecaster.getName()); + assertEquals(description, forecaster.getDescription()); + assertEquals(timeField, forecaster.getTimeField()); + assertEquals(indices, forecaster.getIndices()); + assertEquals(features, forecaster.getFeatureAttributes()); + assertEquals(filterQuery, forecaster.getFilterQuery()); + assertEquals(forecastInterval, forecaster.getInterval()); + assertEquals(windowDelay, forecaster.getWindowDelay()); + assertEquals(shingleSize, forecaster.getShingleSize()); + assertEquals(uiMetadata, forecaster.getUiMetadata()); + assertEquals(schemaVersion, forecaster.getSchemaVersion()); + assertEquals(lastUpdateTime, forecaster.getLastUpdateTime()); + assertEquals(categoryFields, forecaster.getCategoryFields()); + assertEquals(user, forecaster.getUser()); + assertEquals(resultIndex, forecaster.getCustomResultIndex()); + assertEquals(horizon, forecaster.getHorizon()); + assertEquals(imputationOption, forecaster.getImputationOption()); + } + + public void testForecasterConstructorWithNullForecastInterval() { + TimeConfiguration forecastInterval = null; + + ValidationException ex = expectThrows(ValidationException.class, () -> { + new Forecaster( + forecasterId, + version, + name, + description, + timeField, + indices, + features, + filterQuery, + forecastInterval, + windowDelay, + shingleSize, + uiMetadata, + schemaVersion, + lastUpdateTime, + categoryFields, + user, + resultIndex, + horizon, + TestHelpers.randomImputationOption() + ); + }); + + MatcherAssert.assertThat(ex.getMessage(), containsString(ForecastCommonMessages.NULL_FORECAST_INTERVAL)); + MatcherAssert.assertThat(ex.getType(), is(ValidationIssueType.FORECAST_INTERVAL)); + MatcherAssert.assertThat(ex.getAspect(), is(ValidationAspect.FORECASTER)); + } + + public void testNegativeInterval() { + var forecastInterval = new IntervalTimeConfiguration(0, ChronoUnit.MINUTES); // An interval less than or equal to zero + + ValidationException ex = expectThrows(ValidationException.class, () -> { + new Forecaster( + forecasterId, + version, + name, + description, + timeField, + indices, + features, + filterQuery, + forecastInterval, + windowDelay, + shingleSize, + uiMetadata, + 
schemaVersion, + lastUpdateTime, + categoryFields, + user, + resultIndex, + horizon, + TestHelpers.randomImputationOption() + ); + }); + + MatcherAssert.assertThat(ex.getMessage(), containsString(ForecastCommonMessages.INVALID_FORECAST_INTERVAL)); + MatcherAssert.assertThat(ex.getType(), is(ValidationIssueType.FORECAST_INTERVAL)); + MatcherAssert.assertThat(ex.getAspect(), is(ValidationAspect.FORECASTER)); + } + + public void testMaxCategoryFieldsLimits() { + List categoryFields = Arrays.asList("field1", "field2", "field3"); + + ValidationException ex = expectThrows(ValidationException.class, () -> { + new Forecaster( + forecasterId, + version, + name, + description, + timeField, + indices, + features, + filterQuery, + forecastInterval, + windowDelay, + shingleSize, + uiMetadata, + schemaVersion, + lastUpdateTime, + categoryFields, + user, + resultIndex, + horizon, + TestHelpers.randomImputationOption() + ); + }); + + MatcherAssert.assertThat(ex.getMessage(), containsString(CommonMessages.getTooManyCategoricalFieldErr(2))); + MatcherAssert.assertThat(ex.getType(), is(ValidationIssueType.CATEGORY)); + MatcherAssert.assertThat(ex.getAspect(), is(ValidationAspect.FORECASTER)); + } + + public void testBlankName() { + String name = ""; + + ValidationException ex = expectThrows(ValidationException.class, () -> { + new Forecaster( + forecasterId, + version, + name, + description, + timeField, + indices, + features, + filterQuery, + forecastInterval, + windowDelay, + shingleSize, + uiMetadata, + schemaVersion, + lastUpdateTime, + categoryFields, + user, + resultIndex, + horizon, + TestHelpers.randomImputationOption() + ); + }); + + MatcherAssert.assertThat(ex.getMessage(), containsString(CommonMessages.EMPTY_NAME)); + MatcherAssert.assertThat(ex.getType(), is(ValidationIssueType.NAME)); + MatcherAssert.assertThat(ex.getAspect(), is(ValidationAspect.FORECASTER)); + } + + public void testInvalidCustomResultIndex() { + String resultIndex = "test"; + + ValidationException ex = expectThrows(ValidationException.class, () -> { + new Forecaster( + forecasterId, + version, + name, + description, + timeField, + indices, + features, + filterQuery, + forecastInterval, + windowDelay, + shingleSize, + uiMetadata, + schemaVersion, + lastUpdateTime, + categoryFields, + user, + resultIndex, + horizon, + TestHelpers.randomImputationOption() + ); + }); + + MatcherAssert.assertThat(ex.getMessage(), containsString(ForecastCommonMessages.INVALID_RESULT_INDEX_PREFIX)); + MatcherAssert.assertThat(ex.getType(), is(ValidationIssueType.RESULT_INDEX)); + MatcherAssert.assertThat(ex.getAspect(), is(ValidationAspect.FORECASTER)); + } + + public void testValidCustomResultIndex() { + String resultIndex = ForecastCommonName.CUSTOM_RESULT_INDEX_PREFIX + "test"; + + var forecaster = new Forecaster( + forecasterId, + version, + name, + description, + timeField, + indices, + features, + filterQuery, + forecastInterval, + windowDelay, + shingleSize, + uiMetadata, + schemaVersion, + lastUpdateTime, + categoryFields, + user, + resultIndex, + horizon, + TestHelpers.randomImputationOption() + ); + + assertEquals(resultIndex, forecaster.getCustomResultIndex()); + } + + public void testInvalidHorizon() { + int horizon = 0; + + ValidationException ex = expectThrows(ValidationException.class, () -> { + new Forecaster( + forecasterId, + version, + name, + description, + timeField, + indices, + features, + filterQuery, + forecastInterval, + windowDelay, + shingleSize, + uiMetadata, + schemaVersion, + lastUpdateTime, + categoryFields, + user, + 
resultIndex, + horizon, + TestHelpers.randomImputationOption() + ); + }); + + MatcherAssert.assertThat(ex.getMessage(), containsString("Horizon size must be a positive integer no larger than")); + MatcherAssert.assertThat(ex.getType(), is(ValidationIssueType.SHINGLE_SIZE_FIELD)); + MatcherAssert.assertThat(ex.getAspect(), is(ValidationAspect.FORECASTER)); + } + + public void testParse() throws IOException { + Forecaster forecaster = TestHelpers.randomForecaster(); + String forecasterString = TestHelpers + .xContentBuilderToString(forecaster.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + LOG.info(forecasterString); + Forecaster parsedForecaster = Forecaster.parse(TestHelpers.parser(forecasterString)); + assertEquals("Parsing forecaster doesn't work", forecaster, parsedForecaster); + } + + public void testParseEmptyMetaData() throws IOException { + Forecaster forecaster = TestHelpers.ForecasterBuilder.newInstance().setUiMetadata(null).build(); + String forecasterString = TestHelpers + .xContentBuilderToString(forecaster.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + LOG.info(forecasterString); + Forecaster parsedForecaster = Forecaster.parse(TestHelpers.parser(forecasterString)); + assertEquals("Parsing forecaster doesn't work", forecaster, parsedForecaster); + } + + public void testParseNullLastUpdateTime() throws IOException { + Forecaster forecaster = TestHelpers.ForecasterBuilder.newInstance().setLastUpdateTime(null).build(); + String forecasterString = TestHelpers + .xContentBuilderToString(forecaster.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + LOG.info(forecasterString); + Forecaster parsedForecaster = Forecaster.parse(TestHelpers.parser(forecasterString)); + assertEquals("Parsing forecaster doesn't work", forecaster, parsedForecaster); + } + + public void testParseNullCategoryFields() throws IOException { + Forecaster forecaster = TestHelpers.ForecasterBuilder.newInstance().setCategoryFields(null).build(); + String forecasterString = TestHelpers + .xContentBuilderToString(forecaster.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + LOG.info(forecasterString); + Forecaster parsedForecaster = Forecaster.parse(TestHelpers.parser(forecasterString)); + assertEquals("Parsing forecaster doesn't work", forecaster, parsedForecaster); + } + + public void testParseNullUser() throws IOException { + Forecaster forecaster = TestHelpers.ForecasterBuilder.newInstance().setUser(null).build(); + String forecasterString = TestHelpers + .xContentBuilderToString(forecaster.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + LOG.info(forecasterString); + Forecaster parsedForecaster = Forecaster.parse(TestHelpers.parser(forecasterString)); + assertEquals("Parsing forecaster doesn't work", forecaster, parsedForecaster); + } + + public void testParseNullCustomResultIndex() throws IOException { + Forecaster forecaster = TestHelpers.ForecasterBuilder.newInstance().setCustomResultIndex(null).build(); + String forecasterString = TestHelpers + .xContentBuilderToString(forecaster.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + LOG.info(forecasterString); + Forecaster parsedForecaster = Forecaster.parse(TestHelpers.parser(forecasterString)); + assertEquals("Parsing forecaster doesn't work", forecaster, parsedForecaster); + } + + public void testParseNullImpute() throws IOException { + Forecaster forecaster = TestHelpers.ForecasterBuilder.newInstance().setNullImputationOption().build(); + String forecasterString = TestHelpers + 
.xContentBuilderToString(forecaster.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); + LOG.info(forecasterString); + Forecaster parsedForecaster = Forecaster.parse(TestHelpers.parser(forecasterString)); + assertEquals("Parsing forecaster doesn't work", forecaster, parsedForecaster); + } + + public void testGetImputer() throws IOException { + Forecaster forecaster = TestHelpers.randomForecaster(); + assertTrue(null != forecaster.getImputer()); + } + + public void testGetImputerNullImputer() throws IOException { + Forecaster forecaster = TestHelpers.ForecasterBuilder.newInstance().setNullImputationOption().build(); + assertTrue(null != forecaster.getImputer()); + } +} diff --git a/src/test/java/org/opensearch/forecast/settings/ForecastEnabledSettingTests.java-e b/src/test/java/org/opensearch/forecast/settings/ForecastEnabledSettingTests.java-e new file mode 100644 index 000000000..dda3a8761 --- /dev/null +++ b/src/test/java/org/opensearch/forecast/settings/ForecastEnabledSettingTests.java-e @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.settings; + +import org.opensearch.test.OpenSearchTestCase; + +public class ForecastEnabledSettingTests extends OpenSearchTestCase { + + public void testIsForecastEnabled() { + assertTrue(ForecastEnabledSetting.isForecastEnabled()); + ForecastEnabledSetting.getInstance().setSettingValue(ForecastEnabledSetting.FORECAST_ENABLED, false); + assertTrue(!ForecastEnabledSetting.isForecastEnabled()); + } + + public void testIsForecastBreakerEnabled() { + assertTrue(ForecastEnabledSetting.isForecastBreakerEnabled()); + ForecastEnabledSetting.getInstance().setSettingValue(ForecastEnabledSetting.FORECAST_BREAKER_ENABLED, false); + assertTrue(!ForecastEnabledSetting.isForecastBreakerEnabled()); + } + + public void testIsDoorKeeperInCacheEnabled() { + assertTrue(!ForecastEnabledSetting.isDoorKeeperInCacheEnabled()); + ForecastEnabledSetting.getInstance().setSettingValue(ForecastEnabledSetting.FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED, true); + assertTrue(ForecastEnabledSetting.isDoorKeeperInCacheEnabled()); + } + +} diff --git a/src/test/java/org/opensearch/forecast/settings/ForecastNumericSettingTests.java-e b/src/test/java/org/opensearch/forecast/settings/ForecastNumericSettingTests.java-e new file mode 100644 index 000000000..80b2202bf --- /dev/null +++ b/src/test/java/org/opensearch/forecast/settings/ForecastNumericSettingTests.java-e @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.settings; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.opensearch.common.settings.Setting; +import org.opensearch.test.OpenSearchTestCase; + +public class ForecastNumericSettingTests extends OpenSearchTestCase { + private ForecastNumericSetting forecastSetting; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + forecastSetting = ForecastNumericSetting.getInstance(); + } + + public void testMaxCategoricalFields() { + forecastSetting.setSettingValue(ForecastNumericSetting.CATEGORY_FIELD_LIMIT, 3); + int value = ForecastNumericSetting.maxCategoricalFields(); + assertEquals("Expected value is 3", 3, value); + } + + public void testGetSettingValue() { + Map> settingsMap = new HashMap<>(); + Setting testSetting = Setting.intSetting("test.setting", 1, Setting.Property.NodeScope); + settingsMap.put("test.setting", testSetting); + 
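// build the setting instance from the test-only map so the lookup below resolves "test.setting" +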
forecastSetting = new ForecastNumericSetting(settingsMap); + + forecastSetting.setSettingValue("test.setting", 2); + Integer value = forecastSetting.getSettingValue("test.setting"); + assertEquals("Expected value is 2", 2, value.intValue()); + } + + public void testGetSettingNonexistentKey() { + try { + forecastSetting.getSettingValue("nonexistent.key"); + fail("Expected an IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertEquals("Cannot find setting by key [nonexistent.key]", e.getMessage()); + } + } +} diff --git a/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java-e b/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java-e new file mode 100644 index 000000000..7d9f9b1b2 --- /dev/null +++ b/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java-e @@ -0,0 +1,275 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.search.aggregations.metrics; + +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.temporal.ChronoUnit; +import java.util.*; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.AbstractProfileRunnerTests; +import org.opensearch.ad.AnomalyDetectorProfileRunner; +import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.transport.ProfileAction; +import org.opensearch.ad.transport.ProfileNodeResponse; +import org.opensearch.ad.transport.ProfileResponse; +import org.opensearch.ad.util.SecurityClientUtil; +import org.opensearch.cluster.ClusterName; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.BigArrays; +import org.opensearch.search.aggregations.InternalAggregation; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; + +import com.carrotsearch.hppc.BitMixer; + +/** + * Run tests in ES package since InternalCardinality has only package private constructors + * and we cannot mock it since it is a final class. 
+ * + */ +public class CardinalityProfileTests extends AbstractProfileRunnerTests { + enum ADResultStatus { + NO_RESULT, + EXCEPTION + } + + enum CardinalityStatus { + EXCEPTION, + NORMAL + } + + @SuppressWarnings("unchecked") + private void setUpMultiEntityClientGet(DetectorStatus detectorStatus, JobStatus jobStatus, ErrorResultStatus errorResultStatus) + throws IOException { + detector = TestHelpers + .randomAnomalyDetectorWithInterval(new IntervalTimeConfiguration(detectorIntervalMin, ChronoUnit.MINUTES), true); + NodeStateManager nodeStateManager = mock(NodeStateManager.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(1); + listener.onResponse(Optional.of(detector)); + return null; + }).when(nodeStateManager).getAnomalyDetector(anyString(), any(ActionListener.class)); + clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); + runner = new AnomalyDetectorProfileRunner( + client, + clientUtil, + xContentRegistry(), + nodeFilter, + requiredSamples, + transportService, + adTaskManager + ); + + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + GetRequest request = (GetRequest) args[0]; + ActionListener listener = (ActionListener) args[1]; + + if (request.index().equals(CommonName.CONFIG_INDEX)) { + switch (detectorStatus) { + case EXIST: + listener.onResponse(TestHelpers.createGetResponse(detector, detector.getId(), CommonName.CONFIG_INDEX)); + break; + default: + assertTrue("should not reach here", false); + break; + } + } else if (request.index().equals(CommonName.JOB_INDEX)) { + AnomalyDetectorJob job = null; + switch (jobStatus) { + case ENABLED: + job = TestHelpers.randomAnomalyDetectorJob(true); + listener.onResponse(TestHelpers.createGetResponse(job, detector.getId(), CommonName.JOB_INDEX)); + break; + default: + assertTrue("should not reach here", false); + break; + } + } else if (request.index().equals(ADCommonName.DETECTION_STATE_INDEX)) { + switch (errorResultStatus) { + case NO_ERROR: + listener.onResponse(null); + break; + case NULL_POINTER_EXCEPTION: + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(true); + doThrow(NullPointerException.class).when(response).getSourceAsString(); + listener.onResponse(response); + break; + default: + assertTrue("should not reach here", false); + break; + } + } + return null; + }).when(client).get(any(), any()); + } + + @SuppressWarnings("unchecked") + private void setUpMultiEntityClientSearch(ADResultStatus resultStatus, CardinalityStatus cardinalityStatus) { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + ActionListener listener = (ActionListener) args[1]; + SearchRequest request = (SearchRequest) args[0]; + if (request.indices()[0].equals(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)) { + switch (resultStatus) { + case NO_RESULT: + SearchResponse mockResponse = mock(SearchResponse.class); + when(mockResponse.getHits()).thenReturn(TestHelpers.createSearchHits(0)); + listener.onResponse(mockResponse); + break; + case EXCEPTION: + listener.onFailure(new RuntimeException()); + break; + default: + assertTrue("should not reach here", false); + break; + } + } else { + switch (cardinalityStatus) { + case EXCEPTION: + listener.onFailure(new RuntimeException()); + break; + case NORMAL: + SearchResponse response = mock(SearchResponse.class); + List aggs = new ArrayList<>(1); + HyperLogLogPlusPlus hyperLogLog = new HyperLogLogPlusPlus( + AbstractHyperLogLog.MIN_PRECISION, + BigArrays.NON_RECYCLING_INSTANCE, + 0 + ); + 
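// feed the sketch random values so the stubbed cardinality aggregation reports a nonzero entity count +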
for (int i = 0; i < 100; i++) { + hyperLogLog.collect(0, BitMixer.mix64(randomIntBetween(1, 100))); + } + aggs.add(new InternalCardinality(ADCommonName.TOTAL_ENTITIES, hyperLogLog, new HashMap<>())); + when(response.getAggregations()).thenReturn(InternalAggregations.from(aggs)); + listener.onResponse(response); + break; + default: + assertTrue("should not reach here", false); + break; + } + + } + + return null; + }).when(client).search(any(), any()); + } + + @SuppressWarnings("unchecked") + private void setUpProfileAction() { + doAnswer(invocation -> { + Object[] args = invocation.getArguments(); + + ActionListener listener = (ActionListener) args[2]; + + ProfileNodeResponse profileNodeResponse1 = new ProfileNodeResponse( + discoveryNode1, + new HashMap<>(), + shingleSize, + 0, + 0, + new ArrayList<>(), + 0 + ); + List profileNodeResponses = Arrays.asList(profileNodeResponse1); + listener.onResponse(new ProfileResponse(new ClusterName(clusterName), profileNodeResponses, Collections.emptyList())); + + return null; + }).when(client).execute(eq(ProfileAction.INSTANCE), any(), any()); + } + + public void testFailGetEntityStats() throws IOException, InterruptedException { + setUpMultiEntityClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, ErrorResultStatus.NO_ERROR); + setUpMultiEntityClientSearch(ADResultStatus.NO_RESULT, CardinalityStatus.EXCEPTION); + setUpProfileAction(); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detector.getId(), ActionListener.wrap(response -> { + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }, exception -> { + assertTrue(exception instanceof RuntimeException); + // this means we don't exit with failImmediately. failImmediately can make we return early when there are other concurrent + // requests + assertTrue(exception.getMessage(), exception.getMessage().contains("Exceptions:")); + inProgressLatch.countDown(); + + }), totalInitProgress); + + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testNoResultsNoError() throws IOException, InterruptedException { + setUpMultiEntityClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, ErrorResultStatus.NO_ERROR); + setUpMultiEntityClientSearch(ADResultStatus.NO_RESULT, CardinalityStatus.NORMAL); + setUpProfileAction(); + + final AtomicInteger called = new AtomicInteger(0); + + runner.profile(detector.getId(), ActionListener.wrap(response -> { + assertTrue(response.getInitProgress() != null); + called.getAndIncrement(); + }, exception -> { + assertTrue("Should not reach here ", false); + called.getAndIncrement(); + }), totalInitProgress); + + while (called.get() == 0) { + Thread.sleep(100); + } + // should only call onResponse once + assertEquals(1, called.get()); + } + + public void testFailConfirmInitted() throws IOException, InterruptedException { + setUpMultiEntityClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, ErrorResultStatus.NO_ERROR); + setUpMultiEntityClientSearch(ADResultStatus.EXCEPTION, CardinalityStatus.NORMAL); + setUpProfileAction(); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + runner.profile(detector.getId(), ActionListener.wrap(response -> { + assertTrue("Should not reach here ", false); + inProgressLatch.countDown(); + }, exception -> { + assertTrue(exception instanceof RuntimeException); + // this means we don't exit with failImmediately. 
failImmediately can make we return early when there are other concurrent + // requests + assertTrue(exception.getMessage(), exception.getMessage().contains("Exceptions:")); + inProgressLatch.countDown(); + + }), totalInitProgress); + + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } +} diff --git a/src/test/java/org/opensearch/timeseries/AbstractTimeSeriesTest.java b/src/test/java/org/opensearch/timeseries/AbstractTimeSeriesTest.java index c3f5c161b..8799b9be6 100644 --- a/src/test/java/org/opensearch/timeseries/AbstractTimeSeriesTest.java +++ b/src/test/java/org/opensearch/timeseries/AbstractTimeSeriesTest.java @@ -43,23 +43,22 @@ import org.opensearch.Version; import org.opensearch.action.ActionResponse; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.ad.AnomalyDetectorPlugin; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.DetectorInternalState; import org.opensearch.cluster.metadata.AliasMetadata; import org.opensearch.cluster.metadata.IndexMetadata; -import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.logging.Loggers; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.http.HttpRequest; import org.opensearch.http.HttpResponse; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestRequest.Method; -import org.opensearch.rest.RestStatus; import org.opensearch.search.SearchModule; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.FixedExecutorBuilder; @@ -261,10 +260,10 @@ protected static void setUpThreadPool(String name) { name, new FixedExecutorBuilder( Settings.EMPTY, - AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, 1, 1000, - "opensearch.ad." + AnomalyDetectorPlugin.AD_THREAD_POOL_NAME + "opensearch.ad." + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ) ); } @@ -446,7 +445,7 @@ protected IndexMetadata indexMeta(String name, long creationDate, String... alia protected void setUpADThreadPool(ThreadPool mockThreadPool) { ExecutorService executorService = mock(ExecutorService.class); - when(mockThreadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + when(mockThreadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); doAnswer(invocation -> { Runnable runnable = invocation.getArgument(0); runnable.run(); diff --git a/src/test/java/org/opensearch/timeseries/AbstractTimeSeriesTest.java-e b/src/test/java/org/opensearch/timeseries/AbstractTimeSeriesTest.java-e new file mode 100644 index 000000000..46b0043da --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/AbstractTimeSeriesTest.java-e @@ -0,0 +1,455 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries; + +import static org.hamcrest.Matchers.containsString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.apache.commons.lang3.exception.ExceptionUtils; +import org.apache.logging.log4j.Level; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.core.LogEvent; +import org.apache.logging.log4j.core.Logger; +import org.apache.logging.log4j.core.appender.AbstractAppender; +import org.apache.logging.log4j.core.config.Property; +import org.apache.logging.log4j.core.layout.PatternLayout; +import org.apache.logging.log4j.util.StackLocatorUtil; +import org.opensearch.Version; +import org.opensearch.action.ActionResponse; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.model.DetectorInternalState; +import org.opensearch.cluster.metadata.AliasMetadata; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.common.logging.Loggers; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.http.HttpRequest; +import org.opensearch.http.HttpResponse; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestRequest.Method; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.search.SearchModule; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.FixedExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportInterceptor; +import org.opensearch.transport.TransportService; + +import test.org.opensearch.ad.util.FakeNode; + +public class AbstractTimeSeriesTest extends OpenSearchTestCase { + + protected static final Logger LOG = (Logger) LogManager.getLogger(AbstractTimeSeriesTest.class); + + // transport test node + protected int nodesCount; + protected FakeNode[] testNodes; + + /** + * Log4j appender that uses a list to store log messages + * + */ + protected class TestAppender extends AbstractAppender { + private static final String EXCEPTION_CLASS = "exception_class"; + private static final String EXCEPTION_MSG = "exception_message"; + private static final String EXCEPTION_STACK_TRACE = "stacktrace"; + + Map, Map> exceptions; + // whether record exception and its stack trace or not. + // If you log(msg, exception), by default we won't record exception and its stack trace. 
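+ // pass recordExceptions=true to the constructor to capture them for the containException* helpers below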
+ boolean recordExceptions; + + protected TestAppender(String name) { + this(name, false); + } + + protected TestAppender(String name, boolean recordExceptions) { + super(name, null, PatternLayout.createDefaultLayout(), true, Property.EMPTY_ARRAY); + this.recordExceptions = recordExceptions; + if (recordExceptions) { + exceptions = new HashMap, Map>(); + } + } + + public List messages = new ArrayList(); + + public boolean containsMessage(String msg, boolean formatString) { + Pattern p = null; + if (formatString) { + String regex = convertToRegex(msg); + p = Pattern.compile(regex); + } + for (String logMsg : messages) { + LOG.info(logMsg); + if (p != null) { + Matcher m = p.matcher(logMsg); + if (m.matches()) { + return true; + } + } else if (logMsg.contains(msg)) { + return true; + } + } + return false; + } + + public boolean containsMessage(String msg) { + return containsMessage(msg, false); + } + + public int countMessage(String msg, boolean formatString) { + Pattern p = null; + if (formatString) { + String regex = convertToRegex(msg); + p = Pattern.compile(regex); + } + int count = 0; + for (String logMsg : messages) { + LOG.info(logMsg); + if (p != null) { + Matcher m = p.matcher(logMsg); + if (m.matches()) { + count++; + } + } else if (logMsg.contains(msg)) { + count++; + } + } + return count; + } + + public int countMessage(String msg) { + return countMessage(msg, false); + } + + public Boolean containExceptionClass(Class throwable, String className) { + Map throwableInformation = exceptions.get(throwable); + return Optional.ofNullable(throwableInformation).map(m -> m.get(EXCEPTION_CLASS)).map(s -> s.equals(className)).orElse(false); + } + + public Boolean containExceptionMsg(Class throwable, String msg) { + Map throwableInformation = exceptions.get(throwable); + return Optional + .ofNullable(throwableInformation) + .map(m -> m.get(EXCEPTION_MSG)) + .map(s -> ((String) s).contains(msg)) + .orElse(false); + } + + public Boolean containExceptionTrace(Class throwable, String traceElement) { + Map throwableInformation = exceptions.get(throwable); + return Optional + .ofNullable(throwableInformation) + .map(m -> m.get(EXCEPTION_STACK_TRACE)) + .map(s -> ((String) s).contains(traceElement)) + .orElse(false); + } + + @Override + public void append(LogEvent event) { + messages.add(event.getMessage().getFormattedMessage()); + if (recordExceptions && event.getThrown() != null) { + Map throwableInformation = new HashMap(); + final Throwable throwable = event.getThrown(); + if (throwable.getClass().getCanonicalName() != null) { + throwableInformation.put(EXCEPTION_CLASS, throwable.getClass().getCanonicalName()); + } + if (throwable.getMessage() != null) { + throwableInformation.put(EXCEPTION_MSG, throwable.getMessage()); + } + if (throwable.getMessage() != null) { + StringBuilder stackTrace = new StringBuilder(ExceptionUtils.getStackTrace(throwable)); + throwableInformation.put(EXCEPTION_STACK_TRACE, stackTrace.toString()); + } + exceptions.put(throwable.getClass(), throwableInformation); + } + } + + /** + * Convert a string with format like "Cannot save %s due to write block." + * to a regex with .* like "Cannot save .* due to write block." 
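+ * e.g. containsMessage("Cannot save %s due to write block.", true) matches the concrete logged message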
+ * @return converted regex + */ + private String convertToRegex(String formattedStr) { + int percentIndex = formattedStr.indexOf("%"); + return formattedStr.substring(0, percentIndex) + ".*" + formattedStr.substring(percentIndex + 2); + } + + } + + protected static ThreadPool threadPool; + + protected TestAppender testAppender; + + protected Logger logger; + + /** + * Set up test with junit that a warning was logged with log4j + */ + protected void setUpLog4jForJUnit(Class cls, boolean recordExceptions) { + String loggerName = toLoggerName(callerClass(cls)); + logger = (Logger) LogManager.getLogger(loggerName); + Loggers.setLevel(logger, Level.DEBUG); + testAppender = new TestAppender(loggerName, recordExceptions); + testAppender.start(); + logger.addAppender(testAppender); + } + + protected void setUpLog4jForJUnit(Class cls) { + setUpLog4jForJUnit(cls, false); + } + + private static String toLoggerName(final Class cls) { + String canonicalName = cls.getCanonicalName(); + return canonicalName != null ? canonicalName : cls.getName(); + } + + private static Class callerClass(final Class clazz) { + if (clazz != null) { + return clazz; + } + final Class candidate = StackLocatorUtil.getCallerClass(3); + if (candidate == null) { + throw new UnsupportedOperationException("No class provided, and an appropriate one cannot be found."); + } + return candidate; + } + + /** + * remove the appender + */ + protected void tearDownLog4jForJUnit() { + logger.removeAppender(testAppender); + testAppender.stop(); + } + + protected static void setUpThreadPool(String name) { + threadPool = new TestThreadPool( + name, + new FixedExecutorBuilder( + Settings.EMPTY, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + 1, + 1000, + "opensearch.ad." + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME + ) + ); + } + + protected static void tearDownThreadPool() { + LOG.info("tear down threadPool"); + assertTrue(ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS)); + threadPool = null; + } + + /** + * + * @param transportInterceptor Interceptor to for transport requests. Used + * to mock transport layer. + * @param nodeSettings node override of setting + * @param setting the supported setting set. + */ + public void setupTestNodes(TransportInterceptor transportInterceptor, final Settings nodeSettings, Setting... setting) { + setupTestNodes(transportInterceptor, randomIntBetween(2, 10), nodeSettings, Version.CURRENT, setting); + } + + /** + * + * @param transportInterceptor Interceptor to for transport requests. Used + * to mock transport layer. + * @param numberOfNodes number of nodes in the cluster + * @param nodeSettings node override of setting + * @param setting the supported setting set. + */ + public void setupTestNodes( + TransportInterceptor transportInterceptor, + int numberOfNodes, + final Settings nodeSettings, + Version version, + Setting... setting + ) { + nodesCount = numberOfNodes; + testNodes = new FakeNode[nodesCount]; + Set> settingSet = new HashSet<>(Arrays.asList(setting)); + for (int i = 0; i < testNodes.length; i++) { + testNodes[i] = new FakeNode("node" + i, threadPool, nodeSettings, settingSet, transportInterceptor, version); + } + FakeNode.connectNodes(testNodes); + } + + public void setupTestNodes(Setting... 
setting) { + setupTestNodes(TransportService.NOOP_TRANSPORT_INTERCEPTOR, Settings.EMPTY, setting); + } + + public void setupTestNodes(Settings nodeSettings) { + setupTestNodes(TransportService.NOOP_TRANSPORT_INTERCEPTOR, nodeSettings); + } + + public void setupTestNodes(TransportInterceptor transportInterceptor) { + setupTestNodes(transportInterceptor, Settings.EMPTY); + } + + public void tearDownTestNodes() { + if (testNodes == null) { + return; + } + for (FakeNode testNode : testNodes) { + testNode.close(); + } + testNodes = null; + } + + public void assertException( + PlainActionFuture listener, + Class exceptionType, + String msg + ) { + Exception e = expectThrows(exceptionType, () -> listener.actionGet(20_000)); + assertThat("actual message: " + e.getMessage(), e.getMessage(), containsString(msg)); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList()); + List entries = searchModule.getNamedXContents(); + entries + .addAll( + Arrays + .asList( + AnomalyDetector.XCONTENT_REGISTRY, + AnomalyResult.XCONTENT_REGISTRY, + DetectorInternalState.XCONTENT_REGISTRY, + AnomalyDetectorJob.XCONTENT_REGISTRY + ) + ); + return new NamedXContentRegistry(entries); + } + + protected RestRequest createRestRequest(Method method) { + return RestRequest.request(xContentRegistry(), new HttpRequest() { + + @Override + public Method method() { + return method; + } + + @Override + public String uri() { + return "/"; + } + + @Override + public BytesReference content() { + // TODO Auto-generated method stub + return null; + } + + @Override + public Map> getHeaders() { + return new HashMap<>(); + } + + @Override + public List strictCookies() { + // TODO Auto-generated method stub + return null; + } + + @Override + public HttpVersion protocolVersion() { + return HttpRequest.HttpVersion.HTTP_1_1; + } + + @Override + public HttpRequest removeHeader(String header) { + // TODO Auto-generated method stub + return null; + } + + @Override + public HttpResponse createResponse(RestStatus status, BytesReference content) { + // TODO Auto-generated method stub + return null; + } + + @Override + public Exception getInboundException() { + // TODO Auto-generated method stub + return null; + } + + @Override + public void release() { + // TODO Auto-generated method stub + + } + + @Override + public HttpRequest releaseAndCopy() { + // TODO Auto-generated method stub + return null; + } + + }, null); + } + + protected IndexMetadata indexMeta(String name, long creationDate, String... 
aliases) { + IndexMetadata.Builder builder = IndexMetadata + .builder(name) + .settings( + Settings + .builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id) + ); + builder.creationDate(creationDate); + for (String alias : aliases) { + builder.putAlias(AliasMetadata.builder(alias).build()); + } + return builder.build(); + } + + protected void setUpADThreadPool(ThreadPool mockThreadPool) { + ExecutorService executorService = mock(ExecutorService.class); + + when(mockThreadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); + doAnswer(invocation -> { + Runnable runnable = invocation.getArgument(0); + runnable.run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + } +} diff --git a/src/test/java/org/opensearch/timeseries/DataByFeatureIdTests.java b/src/test/java/org/opensearch/timeseries/DataByFeatureIdTests.java index 1cf60dd89..631ba99e3 100644 --- a/src/test/java/org/opensearch/timeseries/DataByFeatureIdTests.java +++ b/src/test/java/org/opensearch/timeseries/DataByFeatureIdTests.java @@ -9,8 +9,8 @@ import org.junit.Before; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; diff --git a/src/test/java/org/opensearch/timeseries/DataByFeatureIdTests.java-e b/src/test/java/org/opensearch/timeseries/DataByFeatureIdTests.java-e new file mode 100644 index 000000000..8387b38a6 --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/DataByFeatureIdTests.java-e @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries; + +import java.io.IOException; + +import org.junit.Before; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.model.DataByFeatureId; + +public class DataByFeatureIdTests extends OpenSearchTestCase { + String expectedFeatureId = "testFeature"; + Double expectedData = 123.45; + DataByFeatureId dataByFeatureId; + + @Before + public void setup() { + dataByFeatureId = new DataByFeatureId(expectedFeatureId, expectedData); + } + + public void testInputOutputStream() throws IOException { + + BytesStreamOutput output = new BytesStreamOutput(); + dataByFeatureId.writeTo(output); + StreamInput streamInput = output.bytes().streamInput(); + DataByFeatureId restoredDataByFeatureId = new DataByFeatureId(streamInput); + assertEquals(expectedFeatureId, restoredDataByFeatureId.getFeatureId()); + assertEquals(expectedData, restoredDataByFeatureId.getData()); + } + + public void testToXContent() throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + + dataByFeatureId.toXContent(builder, ToXContent.EMPTY_PARAMS); + + XContentParser parser = createParser(builder); + // advance to first token + XContentParser.Token token = parser.nextToken(); + if (token != 
XContentParser.Token.START_OBJECT) { + throw new IOException("Expected data to start with an Object"); + } + + DataByFeatureId parsedDataByFeatureId = DataByFeatureId.parse(parser); + + assertEquals(expectedFeatureId, parsedDataByFeatureId.getFeatureId()); + assertEquals(expectedData, parsedDataByFeatureId.getData()); + } + + public void testEqualsAndHashCode() { + DataByFeatureId dataByFeatureId1 = new DataByFeatureId("feature1", 1.0); + DataByFeatureId dataByFeatureId2 = new DataByFeatureId("feature1", 1.0); + DataByFeatureId dataByFeatureId3 = new DataByFeatureId("feature2", 2.0); + + // Test equal objects are equal + assertEquals(dataByFeatureId1, dataByFeatureId2); + assertEquals(dataByFeatureId1.hashCode(), dataByFeatureId2.hashCode()); + + // Test unequal objects are not equal + assertNotEquals(dataByFeatureId1, dataByFeatureId3); + assertNotEquals(dataByFeatureId1.hashCode(), dataByFeatureId3.hashCode()); + + // Test object is not equal to null + assertNotEquals(dataByFeatureId1, null); + + // Test object is not equal to object of different type + assertNotEquals(dataByFeatureId1, "string"); + + // Test object is equal to itself + assertEquals(dataByFeatureId1, dataByFeatureId1); + assertEquals(dataByFeatureId1.hashCode(), dataByFeatureId1.hashCode()); + } +} diff --git a/src/test/java/org/opensearch/timeseries/TestHelpers.java b/src/test/java/org/opensearch/timeseries/TestHelpers.java index 609e079e6..33ceb54fa 100644 --- a/src/test/java/org/opensearch/timeseries/TestHelpers.java +++ b/src/test/java/org/opensearch/timeseries/TestHelpers.java @@ -13,7 +13,7 @@ import static org.apache.hc.core5.http.ContentType.APPLICATION_JSON; import static org.opensearch.cluster.node.DiscoveryNodeRole.BUILT_IN_ROLES; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; import static org.opensearch.test.OpenSearchTestCase.*; @@ -98,8 +98,6 @@ import org.opensearch.common.Priority; import org.opensearch.common.Randomness; import org.opensearch.common.UUIDs; -import org.opensearch.common.bytes.BytesArray; -import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; @@ -108,6 +106,9 @@ import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; @@ -119,7 +120,6 @@ import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; -import org.opensearch.rest.RestStatus; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchModule; diff --git a/src/test/java/org/opensearch/timeseries/TestHelpers.java-e b/src/test/java/org/opensearch/timeseries/TestHelpers.java-e new file mode 100644 index 000000000..b38b1c8d0 --- 
/dev/null +++ b/src/test/java/org/opensearch/timeseries/TestHelpers.java-e @@ -0,0 +1,1768 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries; + +import static org.apache.hc.core5.http.ContentType.APPLICATION_JSON; +import static org.opensearch.cluster.node.DiscoveryNodeRole.BUILT_IN_ROLES; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; +import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; +import static org.opensearch.test.OpenSearchTestCase.*; +import static org.powermock.api.mockito.PowerMockito.mock; +import static org.powermock.api.mockito.PowerMockito.when; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.function.Consumer; +import java.util.stream.DoubleStream; +import java.util.stream.IntStream; + +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.io.entity.StringEntity; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.util.Strings; +import org.apache.lucene.search.TotalHits; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.admin.cluster.health.ClusterHealthResponse; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsResponse; +import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsResponse.FieldMappingMetadata; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.feature.Features; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.mock.model.MockSimpleLog; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskState; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyDetectorExecutionInput; +import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.model.AnomalyResultBucket; +import org.opensearch.ad.model.DetectorInternalState; +import org.opensearch.ad.model.DetectorValidationIssue; +import org.opensearch.ad.model.ExpectedValueList; +import 
org.opensearch.ad.ratelimit.RequestPriority; +import org.opensearch.ad.ratelimit.ResultWriteRequest; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.client.WarningsHandler; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlocks; +import org.opensearch.cluster.metadata.AliasMetadata; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.CheckedConsumer; +import org.opensearch.common.Priority; +import org.opensearch.common.Randomness; +import org.opensearch.common.UUIDs; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.index.get.GetResult; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.SearchModule; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.search.profile.SearchProfileShardResults; +import org.opensearch.search.suggest.Suggest; +import org.opensearch.test.ClusterServiceUtils; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.OpenSearchRestTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.dataprocessor.ImputationMethod; +import org.opensearch.timeseries.dataprocessor.ImputationOption; +import org.opensearch.timeseries.model.DataByFeatureId; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.TimeConfiguration; +import org.opensearch.timeseries.model.ValidationAspect; +import 
org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +public class TestHelpers { + + public static final String LEGACY_OPENDISTRO_AD_BASE_DETECTORS_URI = "/_opendistro/_anomaly_detection/detectors"; + public static final String AD_BASE_DETECTORS_URI = "/_plugins/_anomaly_detection/detectors"; + public static final String AD_BASE_RESULT_URI = AD_BASE_DETECTORS_URI + "/results"; + public static final String AD_BASE_PREVIEW_URI = AD_BASE_DETECTORS_URI + "/%s/_preview"; + public static final String AD_BASE_STATS_URI = "/_plugins/_anomaly_detection/stats"; + public static ImmutableSet HISTORICAL_ANALYSIS_RUNNING_STATS = ImmutableSet + .of(ADTaskState.CREATED.name(), ADTaskState.INIT.name(), ADTaskState.RUNNING.name()); + // Task may fail if memory circuit breaker triggered. + public static final Set HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS = ImmutableSet + .of(ADTaskState.FINISHED.name(), ADTaskState.FAILED.name()); + public static ImmutableSet HISTORICAL_ANALYSIS_DONE_STATS = ImmutableSet + .of(ADTaskState.FAILED.name(), ADTaskState.FINISHED.name(), ADTaskState.STOPPED.name()); + private static final Logger logger = LogManager.getLogger(TestHelpers.class); + public static final Random random = new Random(42); + + public static Response makeRequest( + RestClient client, + String method, + String endpoint, + Map params, + String jsonEntity, + List
headers + ) throws IOException { + HttpEntity httpEntity = Strings.isBlank(jsonEntity) ? null : new StringEntity(jsonEntity, ContentType.APPLICATION_JSON); + return makeRequest(client, method, endpoint, params, httpEntity, headers); + } + + public static Response makeRequest( + RestClient client, + String method, + String endpoint, + Map<String, String> params, + HttpEntity entity, + List<Header> headers + ) throws IOException { + return makeRequest(client, method, endpoint, params, entity, headers, false); + } + + public static Response makeRequest( + RestClient client, + String method, + String endpoint, + Map<String, String> params, + HttpEntity entity, + List<Header>
headers, + boolean strictDeprecationMode + ) throws IOException { + Request request = new Request(method, endpoint); + + RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder(); + if (headers != null) { + headers.forEach(header -> options.addHeader(header.getName(), header.getValue())); + } + options.setWarningsHandler(strictDeprecationMode ? WarningsHandler.STRICT : WarningsHandler.PERMISSIVE); + request.setOptions(options.build()); + + if (params != null) { + params.entrySet().forEach(it -> request.addParameter(it.getKey(), it.getValue())); + } + if (entity != null) { + request.setEntity(entity); + } + return client.performRequest(request); + } + + public static String xContentBuilderToString(XContentBuilder builder) { + return BytesReference.bytes(builder).utf8ToString(); + } + + public static XContentBuilder builder() throws IOException { + return XContentBuilder.builder(XContentType.JSON.xContent()); + } + + public static XContentParser parser(String xc) throws IOException { + return parser(xc, true); + } + + public static XContentParser parser(String xc, boolean skipFirstToken) throws IOException { + XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry(), LoggingDeprecationHandler.INSTANCE, xc); + if (skipFirstToken) { + parser.nextToken(); + } + return parser; + } + + public static Map XContentBuilderToMap(XContentBuilder builder) { + return XContentHelper.convertToMap(BytesReference.bytes(builder), false, builder.contentType()).v2(); + } + + public static NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList()); + return new NamedXContentRegistry(searchModule.getNamedXContents()); + } + + public static AnomalyDetector randomAnomalyDetector(Map uiMetadata, Instant lastUpdateTime) throws IOException { + return randomAnomalyDetector(ImmutableList.of(randomFeature()), uiMetadata, lastUpdateTime, null); + } + + public static AnomalyDetector randomAnomalyDetector(Map uiMetadata, Instant lastUpdateTime, boolean featureEnabled) + throws IOException { + return randomAnomalyDetector(ImmutableList.of(randomFeature(featureEnabled)), uiMetadata, lastUpdateTime, null); + } + + public static AnomalyDetector randomAnomalyDetector(List features, Map uiMetadata, Instant lastUpdateTime) + throws IOException { + return randomAnomalyDetector(features, uiMetadata, lastUpdateTime, null); + } + + public static AnomalyDetector randomAnomalyDetector( + List features, + Map uiMetadata, + Instant lastUpdateTime, + String detectorType + ) throws IOException { + return randomAnomalyDetector(features, uiMetadata, lastUpdateTime, true, null); + } + + public static AnomalyDetector randomAnomalyDetector( + List features, + Map uiMetadata, + Instant lastUpdateTime, + boolean withUser, + List categoryFields + ) throws IOException { + return randomAnomalyDetector( + ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)), + features, + uiMetadata, + lastUpdateTime, + OpenSearchRestTestCase.randomLongBetween(1, 1000), + withUser, + categoryFields + ); + } + + public static AnomalyDetector randomAnomalyDetector( + List indices, + List features, + Map uiMetadata, + Instant lastUpdateTime, + long detectionIntervalInMinutes, + boolean withUser, + List categoryFields + ) throws IOException { + User user = withUser ? 
randomUser() : null; + return new AnomalyDetector( + randomAlphaOfLength(10), + randomLong(), + randomAlphaOfLength(20), + randomAlphaOfLength(30), + randomAlphaOfLength(5), + indices, + features, + randomQuery(), + new IntervalTimeConfiguration(detectionIntervalInMinutes, ChronoUnit.MINUTES), + randomIntervalTimeConfiguration(), + // our test's heap allowance is very small (20 MB heap usage would cause OOM) + // reduce size to not cause issue. + randomIntBetween(1, 20), + uiMetadata, + randomInt(), + lastUpdateTime, + categoryFields, + user, + null, + TestHelpers.randomImputationOption() + ); + } + + public static AnomalyDetector randomDetector(List features, String indexName, int detectionIntervalInMinutes, String timeField) + throws IOException { + return randomDetector(features, indexName, detectionIntervalInMinutes, timeField, null); + } + + public static AnomalyDetector randomDetector( + List features, + String indexName, + int detectionIntervalInMinutes, + String timeField, + List categoryFields + ) throws IOException { + return randomDetector(features, indexName, detectionIntervalInMinutes, timeField, categoryFields, null); + } + + public static AnomalyDetector randomDetector( + List features, + String indexName, + int detectionIntervalInMinutes, + String timeField, + List categoryFields, + String resultIndex + ) throws IOException { + return new AnomalyDetector( + randomAlphaOfLength(10), + randomLong(), + randomAlphaOfLength(20), + randomAlphaOfLength(30), + timeField, + ImmutableList.of(indexName), + features, + randomQuery("{\"bool\":{\"filter\":[{\"exists\":{\"field\":\"value\"}}]}}"), + new IntervalTimeConfiguration(detectionIntervalInMinutes, ChronoUnit.MINUTES), + new IntervalTimeConfiguration(OpenSearchRestTestCase.randomLongBetween(1, 5), ChronoUnit.MINUTES), + 8, + null, + randomInt(), + Instant.now(), + categoryFields, + null, + resultIndex, + TestHelpers.randomImputationOption() + ); + } + + public static DateRange randomDetectionDateRange() { + return new DateRange( + Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(10, ChronoUnit.DAYS), + Instant.now().truncatedTo(ChronoUnit.SECONDS) + ); + } + + public static AnomalyDetector randomAnomalyDetectorUsingCategoryFields(String detectorId, List categoryFields) + throws IOException { + return randomAnomalyDetectorUsingCategoryFields( + detectorId, + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)), + categoryFields + ); + } + + public static AnomalyDetector randomAnomalyDetectorUsingCategoryFields( + String detectorId, + String timeField, + List indices, + List categoryFields + ) throws IOException { + return randomAnomalyDetectorUsingCategoryFields(detectorId, timeField, indices, categoryFields, null); + } + + public static AnomalyDetector randomAnomalyDetectorUsingCategoryFields( + String detectorId, + String timeField, + List indices, + List categoryFields, + String resultIndex + ) throws IOException { + return new AnomalyDetector( + detectorId, + randomLong(), + randomAlphaOfLength(20), + randomAlphaOfLength(30), + timeField, + indices, + ImmutableList.of(randomFeature(true)), + randomQuery(), + randomIntervalTimeConfiguration(), + new IntervalTimeConfiguration(0, ChronoUnit.MINUTES), + randomIntBetween(1, TimeSeriesSettings.MAX_SHINGLE_SIZE), + null, + randomInt(), + Instant.now(), + categoryFields, + randomUser(), + resultIndex, + TestHelpers.randomImputationOption() + ); + } + + public static AnomalyDetector randomAnomalyDetector(List features) throws IOException { + 
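// delegate to the overload below that also randomizes the time field and index name +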
return randomAnomalyDetector(randomAlphaOfLength(5), randomAlphaOfLength(10).toLowerCase(Locale.ROOT), features); + } + + public static AnomalyDetector randomAnomalyDetector(String timefield, String indexName) throws IOException { + return randomAnomalyDetector(timefield, indexName, ImmutableList.of(randomFeature(true))); + } + + public static AnomalyDetector randomAnomalyDetector(String timefield, String indexName, List features) throws IOException { + return new AnomalyDetector( + randomAlphaOfLength(10), + randomLong(), + randomAlphaOfLength(20), + randomAlphaOfLength(30), + timefield, + ImmutableList.of(indexName.toLowerCase(Locale.ROOT)), + features, + randomQuery(), + randomIntervalTimeConfiguration(), + randomIntervalTimeConfiguration(), + randomIntBetween(1, TimeSeriesSettings.MAX_SHINGLE_SIZE), + null, + randomInt(), + Instant.now(), + null, + randomUser(), + null, + TestHelpers.randomImputationOption() + ); + } + + public static AnomalyDetector randomAnomalyDetectorWithEmptyFeature() throws IOException { + return new AnomalyDetector( + randomAlphaOfLength(10), + randomLong(), + randomAlphaOfLength(20), + randomAlphaOfLength(30), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)), + ImmutableList.of(), + randomQuery(), + randomIntervalTimeConfiguration(), + randomIntervalTimeConfiguration(), + randomIntBetween(1, TimeSeriesSettings.MAX_SHINGLE_SIZE), + null, + randomInt(), + Instant.now().truncatedTo(ChronoUnit.SECONDS), + null, + randomUser(), + null, + TestHelpers.randomImputationOption() + ); + } + + public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguration interval) throws IOException { + return randomAnomalyDetectorWithInterval(interval, false); + } + + public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguration interval, boolean hcDetector) throws IOException { + List categoryField = hcDetector ? 
ImmutableList.of(randomAlphaOfLength(5)) : null; + return new AnomalyDetector( + randomAlphaOfLength(10), + randomLong(), + randomAlphaOfLength(20), + randomAlphaOfLength(30), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)), + ImmutableList.of(randomFeature()), + randomQuery(), + interval, + randomIntervalTimeConfiguration(), + randomIntBetween(1, TimeSeriesSettings.MAX_SHINGLE_SIZE), + null, + randomInt(), + Instant.now().truncatedTo(ChronoUnit.SECONDS), + categoryField, + randomUser(), + null, + TestHelpers.randomImputationOption() + ); + } + + public static AnomalyResultBucket randomAnomalyResultBucket() { + Map map = new HashMap<>(); + map.put(randomAlphaOfLength(5), randomAlphaOfLength(5)); + return new AnomalyResultBucket(map, randomInt(), randomDouble()); + } + + public static class AnomalyDetectorBuilder { + private String detectorId = randomAlphaOfLength(10); + private Long version = randomLong(); + private String name = randomAlphaOfLength(20); + private String description = randomAlphaOfLength(30); + private String timeField = randomAlphaOfLength(5); + private List indices = ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)); + private List featureAttributes = ImmutableList.of(randomFeature(true)); + private QueryBuilder filterQuery; + private TimeConfiguration detectionInterval = randomIntervalTimeConfiguration(); + private TimeConfiguration windowDelay = randomIntervalTimeConfiguration(); + private Integer shingleSize = randomIntBetween(1, TimeSeriesSettings.MAX_SHINGLE_SIZE); + private Map uiMetadata = null; + private Integer schemaVersion = randomInt(); + private Instant lastUpdateTime = Instant.now().truncatedTo(ChronoUnit.SECONDS); + private List categoryFields = null; + private User user = randomUser(); + private String resultIndex = null; + private ImputationOption imputationOption = null; + + public static AnomalyDetectorBuilder newInstance() throws IOException { + return new AnomalyDetectorBuilder(); + } + + private AnomalyDetectorBuilder() throws IOException { + filterQuery = randomQuery(); + } + + public AnomalyDetectorBuilder setDetectorId(String detectorId) { + this.detectorId = detectorId; + return this; + } + + public AnomalyDetectorBuilder setVersion(Long version) { + this.version = version; + return this; + } + + public AnomalyDetectorBuilder setName(String name) { + this.name = name; + return this; + } + + public AnomalyDetectorBuilder setDescription(String description) { + this.description = description; + return this; + } + + public AnomalyDetectorBuilder setTimeField(String timeField) { + this.timeField = timeField; + return this; + } + + public AnomalyDetectorBuilder setIndices(List indices) { + this.indices = indices; + return this; + } + + public AnomalyDetectorBuilder setFeatureAttributes(List featureAttributes) { + this.featureAttributes = featureAttributes; + return this; + } + + public AnomalyDetectorBuilder setFilterQuery(QueryBuilder filterQuery) { + this.filterQuery = filterQuery; + return this; + } + + public AnomalyDetectorBuilder setDetectionInterval(TimeConfiguration detectionInterval) { + this.detectionInterval = detectionInterval; + return this; + } + + public AnomalyDetectorBuilder setWindowDelay(TimeConfiguration windowDelay) { + this.windowDelay = windowDelay; + return this; + } + + public AnomalyDetectorBuilder setShingleSize(Integer shingleSize) { + this.shingleSize = shingleSize; + return this; + } + + public AnomalyDetectorBuilder setUiMetadata(Map uiMetadata) { + 
this.uiMetadata = uiMetadata; + return this; + } + + public AnomalyDetectorBuilder setSchemaVersion(Integer schemaVersion) { + this.schemaVersion = schemaVersion; + return this; + } + + public AnomalyDetectorBuilder setLastUpdateTime(Instant lastUpdateTime) { + this.lastUpdateTime = lastUpdateTime; + return this; + } + + public AnomalyDetectorBuilder setCategoryFields(List categoryFields) { + this.categoryFields = categoryFields; + return this; + } + + public AnomalyDetectorBuilder setUser(User user) { + this.user = user; + return this; + } + + public AnomalyDetectorBuilder setResultIndex(String resultIndex) { + this.resultIndex = resultIndex; + return this; + } + + public AnomalyDetectorBuilder setImputationOption(ImputationMethod method, Optional defaultFill, boolean integerSentive) { + this.imputationOption = new ImputationOption(method, defaultFill, integerSentive); + return this; + } + + public AnomalyDetector build() { + return new AnomalyDetector( + detectorId, + version, + name, + description, + timeField, + indices, + featureAttributes, + filterQuery, + detectionInterval, + windowDelay, + shingleSize, + uiMetadata, + schemaVersion, + lastUpdateTime, + categoryFields, + user, + resultIndex, + imputationOption + ); + } + } + + public static AnomalyDetector randomAnomalyDetectorWithInterval(TimeConfiguration interval, boolean hcDetector, boolean featureEnabled) + throws IOException { + List categoryField = hcDetector ? ImmutableList.of(randomAlphaOfLength(5)) : null; + return new AnomalyDetector( + randomAlphaOfLength(10), + randomLong(), + randomAlphaOfLength(20), + randomAlphaOfLength(30), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(10).toLowerCase(Locale.ROOT)), + ImmutableList.of(randomFeature(featureEnabled)), + randomQuery(), + interval, + randomIntervalTimeConfiguration(), + randomIntBetween(1, TimeSeriesSettings.MAX_SHINGLE_SIZE), + null, + randomInt(), + Instant.now().truncatedTo(ChronoUnit.SECONDS), + categoryField, + randomUser(), + null, + TestHelpers.randomImputationOption() + ); + } + + public static SearchSourceBuilder randomFeatureQuery() throws IOException { + String query = "{\"query\":{\"match\":{\"user\":{\"query\":\"kimchy\",\"operator\":\"OR\",\"prefix_length\":0," + + "\"max_expansions\":50,\"fuzzy_transpositions\":true,\"lenient\":false,\"zero_terms_query\":\"NONE\"," + + "\"auto_generate_synonyms_phrase_query\":true,\"boost\":1}}}}"; + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList()); + XContentParser parser = XContentType.JSON + .xContent() + .createParser(new NamedXContentRegistry(searchModule.getNamedXContents()), LoggingDeprecationHandler.INSTANCE, query); + searchSourceBuilder.parseXContent(parser); + return searchSourceBuilder; + } + + public static QueryBuilder randomQuery() throws IOException { + String query = "{\"bool\":{\"must\":{\"term\":{\"user\":\"kimchy\"}},\"filter\":{\"term\":{\"tag\":" + + "\"tech\"}},\"must_not\":{\"range\":{\"age\":{\"gte\":10,\"lte\":20}}},\"should\":[{\"term\":" + + "{\"tag\":\"wow\"}},{\"term\":{\"tag\":\"elasticsearch\"}}],\"minimum_should_match\":1,\"boost\":1}}"; + return randomQuery(query); + } + + public static QueryBuilder randomQuery(String query) throws IOException { + XContentParser parser = TestHelpers.parser(query); + return parseInnerQueryBuilder(parser); + } + + public static AggregationBuilder randomAggregation() throws IOException { + return 
randomAggregation(randomAlphaOfLength(5)); + } + + public static AggregationBuilder randomAggregation(String aggregationName) throws IOException { + XContentParser parser = parser("{\"" + aggregationName + "\":{\"value_count\":{\"field\":\"ok\"}}}"); + + AggregatorFactories.Builder parsed = AggregatorFactories.parseAggregators(parser); + return parsed.getAggregatorFactories().iterator().next(); + } + + /** + * Parse string aggregation query into {@link AggregationBuilder} + * Sample input: + * "{\"test\":{\"value_count\":{\"field\":\"ok\"}}}" + * + * @param aggregationQuery aggregation query string + * @return aggregation builder + * @throws IOException IO exception + */ + public static AggregationBuilder parseAggregation(String aggregationQuery) throws IOException { + XContentParser parser = parser(aggregationQuery); + + AggregatorFactories.Builder parsed = AggregatorFactories.parseAggregators(parser); + return parsed.getAggregatorFactories().iterator().next(); + } + + public static Map randomUiMetadata() { + return ImmutableMap.of(randomAlphaOfLength(5), randomFeature()); + } + + public static TimeConfiguration randomIntervalTimeConfiguration() { + return new IntervalTimeConfiguration(OpenSearchRestTestCase.randomLongBetween(1, 1000), ChronoUnit.MINUTES); + } + + public static IntervalSchedule randomIntervalSchedule() { + return new IntervalSchedule( + Instant.now().truncatedTo(ChronoUnit.SECONDS), + OpenSearchRestTestCase.randomIntBetween(1, 1000), + ChronoUnit.MINUTES + ); + } + + public static Feature randomFeature() { + return randomFeature(randomAlphaOfLength(5), randomAlphaOfLength(5)); + } + + public static Feature randomFeature(String featureName, String aggregationName) { + return randomFeature(featureName, aggregationName, randomBoolean()); + } + + public static Feature randomFeature(boolean enabled) { + return randomFeature(randomAlphaOfLength(5), randomAlphaOfLength(5), enabled); + } + + public static Feature randomFeature(String featureName, String aggregationName, boolean enabled) { + AggregationBuilder testAggregation = null; + try { + testAggregation = randomAggregation(aggregationName); + } catch (IOException e) { + logger.error("Failed to generate test aggregation"); + throw new RuntimeException(e); + } + return new Feature(randomAlphaOfLength(5), featureName, enabled, testAggregation); + } + + public static Feature randomFeature(String featureName, String fieldName, String aggregationMethod, boolean enabled) + throws IOException { + XContentParser parser = parser("{\"" + featureName + "\":{\"" + aggregationMethod + "\":{\"field\":\"" + fieldName + "\"}}}"); + AggregatorFactories.Builder aggregators = AggregatorFactories.parseAggregators(parser); + AggregationBuilder testAggregation = aggregators.getAggregatorFactories().iterator().next(); + return new Feature(randomAlphaOfLength(5), featureName, enabled, testAggregation); + } + + public static Features randomFeatures() { + List> ranges = Arrays.asList(new AbstractMap.SimpleEntry<>(0L, 1L)); + double[][] unprocessed = new double[][] { { randomDouble(), randomDouble() } }; + double[][] processed = new double[][] { { randomDouble(), randomDouble() } }; + + return new Features(ranges, unprocessed, processed); + } + + public static List randomThresholdingResults() { + double grade = 1.; + double confidence = 0.5; + double score = 1.; + + ThresholdingResult thresholdingResult = new ThresholdingResult(grade, confidence, score); + List results = new ArrayList<>(); + results.add(thresholdingResult); + return results; + } + + public 
static User randomUser() { + return new User( + randomAlphaOfLength(8), + ImmutableList.of(randomAlphaOfLength(10)), + ImmutableList.of("all_access"), + ImmutableList.of("attribute=test") + ); + } + + public static void assertFailWith(Class clazz, Callable callable) throws Exception { + assertFailWith(clazz, null, callable); + } + + public static void assertFailWith(Class clazz, String message, Callable callable) throws Exception { + try { + callable.call(); + } catch (Throwable e) { + if (e.getClass() != clazz) { + throw e; + } + if (message != null && !e.getMessage().contains(message)) { + throw e; + } + } + } + + public static FeatureData randomFeatureData() { + return new FeatureData(randomAlphaOfLength(5), randomAlphaOfLength(5), randomDouble()); + } + + public static AnomalyResult randomAnomalyDetectResult() { + return randomAnomalyDetectResult(randomDouble(), randomAlphaOfLength(5), null); + } + + public static AnomalyResult randomAnomalyDetectResult(double score) { + return randomAnomalyDetectResult(score, null, null); + } + + public static AnomalyResult randomAnomalyDetectResult(String error) { + return randomAnomalyDetectResult(Double.NaN, error, null); + } + + public static AnomalyResult randomAnomalyDetectResult(double score, String error, String taskId) { + return randomAnomalyDetectResult(score, error, taskId, true); + } + + public static AnomalyResult randomAnomalyDetectResult(double score, String error, String taskId, boolean withUser) { + User user = withUser ? randomUser() : null; + List relavantAttribution = new ArrayList(); + relavantAttribution.add(new DataByFeatureId(randomAlphaOfLength(5), randomDoubleBetween(0, 1.0, true))); + relavantAttribution.add(new DataByFeatureId(randomAlphaOfLength(5), randomDoubleBetween(0, 1.0, true))); + + List pastValues = new ArrayList(); + pastValues.add(new DataByFeatureId(randomAlphaOfLength(5), randomDouble())); + pastValues.add(new DataByFeatureId(randomAlphaOfLength(5), randomDouble())); + + List expectedValuesList = new ArrayList(); + List expectedValues = new ArrayList(); + expectedValues.add(new DataByFeatureId(randomAlphaOfLength(5), randomDouble())); + expectedValues.add(new DataByFeatureId(randomAlphaOfLength(5), randomDouble())); + expectedValuesList.add(new ExpectedValueList(randomDoubleBetween(0, 1.0, true), expectedValues)); + + return new AnomalyResult( + randomAlphaOfLength(5), + taskId, + score, + randomDouble(), + randomDouble(), + ImmutableList.of(randomFeatureData(), randomFeatureData()), + Instant.now().truncatedTo(ChronoUnit.SECONDS), + Instant.now().truncatedTo(ChronoUnit.SECONDS), + Instant.now().truncatedTo(ChronoUnit.SECONDS), + Instant.now().truncatedTo(ChronoUnit.SECONDS), + error, + Optional.empty(), + user, + CommonValue.NO_SCHEMA_VERSION, + null, + Instant.now().truncatedTo(ChronoUnit.SECONDS), + relavantAttribution, + pastValues, + expectedValuesList, + randomDoubleBetween(1.1, 10.0, true) + ); + } + + public static AnomalyResult randomHCADAnomalyDetectResult(double score, double grade) { + return randomHCADAnomalyDetectResult(score, grade, null); + } + + public static ResultWriteRequest randomResultWriteRequest(String detectorId, double score, double grade) { + ResultWriteRequest resultWriteRequest = new ResultWriteRequest( + Instant.now().plus(10, ChronoUnit.MINUTES).toEpochMilli(), + detectorId, + RequestPriority.MEDIUM, + randomHCADAnomalyDetectResult(score, grade), + null + ); + return resultWriteRequest; + } + + public static AnomalyResult randomHCADAnomalyDetectResult(double score, 
double grade, String error) { + return randomHCADAnomalyDetectResult(null, null, score, grade, error, null, null); + } + + public static AnomalyResult randomHCADAnomalyDetectResult( + String detectorId, + String taskId, + double score, + double grade, + String error, + Long startTimeEpochMillis, + Long endTimeEpochMillis + ) { + return randomHCADAnomalyDetectResult(detectorId, taskId, null, score, grade, error, startTimeEpochMillis, endTimeEpochMillis); + } + + public static AnomalyResult randomHCADAnomalyDetectResult( + String detectorId, + String taskId, + Map entityAttrs, + double score, + double grade, + String error, + Long startTimeEpochMillis, + Long endTimeEpochMillis + ) { + List relavantAttribution = new ArrayList(); + relavantAttribution.add(new DataByFeatureId(randomAlphaOfLength(5), randomDoubleBetween(0, 1.0, true))); + relavantAttribution.add(new DataByFeatureId(randomAlphaOfLength(5), randomDoubleBetween(0, 1.0, true))); + + List pastValues = new ArrayList(); + pastValues.add(new DataByFeatureId(randomAlphaOfLength(5), randomDouble())); + pastValues.add(new DataByFeatureId(randomAlphaOfLength(5), randomDouble())); + + List expectedValuesList = new ArrayList(); + List expectedValues = new ArrayList(); + expectedValues.add(new DataByFeatureId(randomAlphaOfLength(5), randomDouble())); + expectedValues.add(new DataByFeatureId(randomAlphaOfLength(5), randomDouble())); + expectedValuesList.add(new ExpectedValueList(randomDoubleBetween(0, 1.0, true), expectedValues)); + + return new AnomalyResult( + detectorId == null ? randomAlphaOfLength(5) : detectorId, + taskId, + score, + grade, + randomDouble(), + ImmutableList.of(randomFeatureData(), randomFeatureData()), + startTimeEpochMillis == null ? Instant.now().truncatedTo(ChronoUnit.SECONDS) : Instant.ofEpochMilli(startTimeEpochMillis), + endTimeEpochMillis == null ? Instant.now().truncatedTo(ChronoUnit.SECONDS) : Instant.ofEpochMilli(endTimeEpochMillis), + startTimeEpochMillis == null ? Instant.now().truncatedTo(ChronoUnit.SECONDS) : Instant.ofEpochMilli(startTimeEpochMillis), + endTimeEpochMillis == null ? Instant.now().truncatedTo(ChronoUnit.SECONDS) : Instant.ofEpochMilli(endTimeEpochMillis), + error, + entityAttrs == null + ? 
Optional.ofNullable(Entity.createSingleAttributeEntity(randomAlphaOfLength(5), randomAlphaOfLength(5))) + : Optional.ofNullable(Entity.createEntityByReordering(entityAttrs)), + randomUser(), + CommonValue.NO_SCHEMA_VERSION, + null, + Instant.now().truncatedTo(ChronoUnit.SECONDS), + relavantAttribution, + pastValues, + expectedValuesList, + randomDoubleBetween(1.1, 10.0, true) + ); + } + + public static AnomalyDetectorJob randomAnomalyDetectorJob() { + return randomAnomalyDetectorJob(true); + } + + public static AnomalyDetectorJob randomAnomalyDetectorJob(boolean enabled, Instant enabledTime, Instant disabledTime) { + return new AnomalyDetectorJob( + randomAlphaOfLength(10), + randomIntervalSchedule(), + randomIntervalTimeConfiguration(), + enabled, + enabledTime, + disabledTime, + Instant.now().truncatedTo(ChronoUnit.SECONDS), + 60L, + randomUser(), + null + ); + } + + public static AnomalyDetectorJob randomAnomalyDetectorJob(boolean enabled) { + return randomAnomalyDetectorJob( + enabled, + Instant.now().truncatedTo(ChronoUnit.SECONDS), + Instant.now().truncatedTo(ChronoUnit.SECONDS) + ); + } + + public static AnomalyDetectorExecutionInput randomAnomalyDetectorExecutionInput() throws IOException { + return new AnomalyDetectorExecutionInput( + randomAlphaOfLength(5), + Instant.now().minus(10, ChronoUnit.MINUTES).truncatedTo(ChronoUnit.SECONDS), + Instant.now().truncatedTo(ChronoUnit.SECONDS), + randomAnomalyDetector(null, Instant.now().truncatedTo(ChronoUnit.SECONDS)) + ); + } + + public static ActionListener createActionListener( + CheckedConsumer consumer, + Consumer failureConsumer + ) { + return ActionListener.wrap(consumer, failureConsumer); + } + + public static void waitForIndexCreationToComplete(Client client, final String indexName) { + ClusterHealthResponse clusterHealthResponse = client + .admin() + .cluster() + .prepareHealth(indexName) + .setWaitForEvents(Priority.URGENT) + .get(); + logger.info("Status of " + indexName + ": " + clusterHealthResponse.getStatus()); + } + + public static ClusterService createClusterService(ThreadPool threadPool, ClusterSettings clusterSettings) { + DiscoveryNode discoveryNode = new DiscoveryNode( + "node", + OpenSearchRestTestCase.buildNewFakeTransportAddress(), + Collections.emptyMap(), + BUILT_IN_ROLES, + Version.CURRENT + ); + return ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); + } + + public static ClusterState createIndexBlockedState(String indexName, Settings hackedSettings, String alias) { + ClusterState blockedClusterState = null; + IndexMetadata.Builder builder = IndexMetadata.builder(indexName); + if (alias != null) { + builder.putAlias(AliasMetadata.builder(alias)); + } + IndexMetadata indexMetaData = builder + .settings( + Settings + .builder() + .put(IndexMetadata.SETTING_INDEX_UUID, UUIDs.randomBase64UUID()) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT) + .put(hackedSettings) + ) + .build(); + Metadata metaData = Metadata.builder().put(indexMetaData, false).build(); + blockedClusterState = ClusterState + .builder(new ClusterName("test cluster")) + .metadata(metaData) + .blocks(ClusterBlocks.builder().addBlocks(indexMetaData)) + .build(); + return blockedClusterState; + } + + public static ThreadContext createThreadContext() { + Settings build = Settings.builder().put("request.headers.default", "1").build(); + ThreadContext context = new ThreadContext(build); + 
context.putHeader("foo", "bar"); + context.putTransient("x", 1); + return context; + } + + public static ThreadPool createThreadPool() { + ThreadPool pool = mock(ThreadPool.class); + when(pool.getThreadContext()).thenReturn(createThreadContext()); + return pool; + } + + public static CreateIndexResponse createIndex(AdminClient adminClient, String indexName, String indexMapping) { + CreateIndexRequest request = new CreateIndexRequest(indexName).mapping(indexMapping); + return adminClient.indices().create(request).actionGet(5_000); + } + + public static void createIndex(RestClient client, String indexName, HttpEntity data) throws IOException { + TestHelpers + .makeRequest( + client, + "POST", + "/" + indexName + "/_doc/" + randomAlphaOfLength(5) + "?refresh=true", + ImmutableMap.of(), + data, + null + ); + } + + public static void createIndexWithTimeField(RestClient client, String indexName, String timeField) throws IOException { + StringBuilder indexMappings = new StringBuilder(); + indexMappings.append("{\"properties\":{"); + indexMappings.append("\"" + timeField + "\":{\"type\":\"date\"}"); + indexMappings.append("}}"); + createIndex(client, indexName.toLowerCase(Locale.ROOT), TestHelpers.toHttpEntity("{\"name\": \"test\"}")); + createIndexMapping(client, indexName.toLowerCase(Locale.ROOT), TestHelpers.toHttpEntity(indexMappings.toString())); + } + + public static void createEmptyIndexWithTimeField(RestClient client, String indexName, String timeField) throws IOException { + StringBuilder indexMappings = new StringBuilder(); + indexMappings.append("{\"properties\":{"); + indexMappings.append("\"" + timeField + "\":{\"type\":\"date\"}"); + indexMappings.append("}}"); + createEmptyIndex(client, indexName.toLowerCase(Locale.ROOT)); + createIndexMapping(client, indexName.toLowerCase(Locale.ROOT), TestHelpers.toHttpEntity(indexMappings.toString())); + } + + public static void createIndexWithHCADFields(RestClient client, String indexName, Map categoryFieldsAndTypes) + throws IOException { + StringBuilder indexMappings = new StringBuilder(); + indexMappings.append("{\"properties\":{"); + for (Map.Entry entry : categoryFieldsAndTypes.entrySet()) { + indexMappings.append("\"" + entry.getKey() + "\":{\"type\":\"" + entry.getValue() + "\"},"); + } + indexMappings.append("\"timestamp\":{\"type\":\"date\"}"); + indexMappings.append("}}"); + createEmptyIndex(client, indexName); + createIndexMapping(client, indexName, TestHelpers.toHttpEntity(indexMappings.toString())); + } + + public static void createEmptyIndexMapping(RestClient client, String indexName, Map fieldsAndTypes) throws IOException { + StringBuilder indexMappings = new StringBuilder(); + indexMappings.append("{\"properties\":{"); + for (Map.Entry entry : fieldsAndTypes.entrySet()) { + indexMappings.append("\"" + entry.getKey() + "\":{\"type\":\"" + entry.getValue() + "\"},"); + } + indexMappings.append("}}"); + createEmptyIndex(client, indexName); + createIndexMapping(client, indexName, TestHelpers.toHttpEntity(indexMappings.toString())); + } + + public static void createEmptyAnomalyResultIndex(RestClient client) throws IOException { + createEmptyIndex(client, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS); + createIndexMapping(client, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, toHttpEntity(ADIndexManagement.getResultMappings())); + } + + public static void createEmptyIndex(RestClient client, String indexName) throws IOException { + TestHelpers.makeRequest(client, "PUT", "/" + indexName, ImmutableMap.of(), "", null); + } + + public static void 
createIndexMapping(RestClient client, String indexName, HttpEntity mappings) throws IOException { + TestHelpers.makeRequest(client, "POST", "/" + indexName + "/_mapping", ImmutableMap.of(), mappings, null); + } + + public static void ingestDataToIndex(RestClient client, String indexName, HttpEntity data) throws IOException { + TestHelpers + .makeRequest( + client, + "POST", + "/" + indexName + "/_doc/" + randomAlphaOfLength(5) + "?refresh=true", + ImmutableMap.of(), + data, + null + ); + } + + public static GetResponse createGetResponse(ToXContentObject o, String id, String indexName) throws IOException { + XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + return new GetResponse( + new GetResult( + indexName, + id, + UNASSIGNED_SEQ_NO, + 0, + -1, + true, + BytesReference.bytes(content), + Collections.emptyMap(), + Collections.emptyMap() + ) + ); + } + + public static GetResponse createBrokenGetResponse(String id, String indexName) throws IOException { + ByteBuffer[] buffers = new ByteBuffer[0]; + return new GetResponse( + new GetResult( + indexName, + id, + UNASSIGNED_SEQ_NO, + 0, + -1, + true, + BytesReference.fromByteBuffers(buffers), + Collections.emptyMap(), + Collections.emptyMap() + ) + ); + } + + public static SearchResponse createSearchResponse(ToXContentObject o) throws IOException { + XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(0).sourceRef(BytesReference.bytes(content)); + + return new SearchResponse( + new InternalSearchResponse( + new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f), + InternalAggregations.EMPTY, + new Suggest(Collections.emptyList()), + new SearchProfileShardResults(Collections.emptyMap()), + false, + false, + 1 + ), + "", + 5, + 5, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + } + + public static SearchResponse createEmptySearchResponse() throws IOException { + return new SearchResponse( + new InternalSearchResponse( + new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1.0f), + InternalAggregations.EMPTY, + new Suggest(Collections.emptyList()), + new SearchProfileShardResults(Collections.emptyMap()), + false, + false, + 1 + ), + "", + 5, + 5, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + } + + public static DetectorInternalState randomDetectState(String error) { + return randomDetectState(error, Instant.now()); + } + + public static DetectorInternalState randomDetectState(Instant lastUpdateTime) { + return randomDetectState(randomAlphaOfLength(5), lastUpdateTime); + } + + public static DetectorInternalState randomDetectState(String error, Instant lastUpdateTime) { + return new DetectorInternalState.Builder().lastUpdateTime(lastUpdateTime).error(error).build(); + } + + public static Map> createFieldMappings( + String index, + String fieldName, + String fieldType + ) throws IOException { + Map> mappings = new HashMap<>(); + FieldMappingMetadata fieldMappingMetadata = new FieldMappingMetadata( + fieldName, + new BytesArray("{\"" + fieldName + "\":{\"type\":\"" + fieldType + "\"}}") + ); + mappings.put(index, Collections.singletonMap(fieldName, fieldMappingMetadata)); + return mappings; + } + + public static ADTask randomAdTask() throws IOException { + return randomAdTask( + randomAlphaOfLength(5), + ADTaskState.RUNNING, + Instant.now().truncatedTo(ChronoUnit.SECONDS), + 
randomAlphaOfLength(5), + true + ); + } + + public static ADTask randomAdTask(ADTaskType adTaskType) throws IOException { + return randomAdTask( + randomAlphaOfLength(5), + ADTaskState.RUNNING, + Instant.now().truncatedTo(ChronoUnit.SECONDS), + randomAlphaOfLength(5), + true, + adTaskType + ); + } + + public static ADTask randomAdTask( + String taskId, + ADTaskState state, + Instant executionEndTime, + String stoppedBy, + String detectorId, + AnomalyDetector detector, + ADTaskType adTaskType + ) { + executionEndTime = executionEndTime == null ? null : executionEndTime.truncatedTo(ChronoUnit.SECONDS); + Entity entity = null; + if (ADTaskType.HISTORICAL_HC_ENTITY == adTaskType) { + List categoryField = detector.getCategoryFields(); + if (categoryField != null) { + if (categoryField.size() == 1) { + entity = Entity.createSingleAttributeEntity(categoryField.get(0), randomAlphaOfLength(5)); + } else if (categoryField.size() == 2) { + entity = Entity + .createEntityByReordering( + ImmutableMap.of(categoryField.get(0), randomAlphaOfLength(5), categoryField.get(1), randomAlphaOfLength(5)) + ); + } + } + } + ADTask task = ADTask + .builder() + .taskId(taskId) + .taskType(adTaskType.name()) + .detectorId(detectorId) + .detector(detector) + .state(state.name()) + .taskProgress(0.5f) + .initProgress(1.0f) + .currentPiece(Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(randomIntBetween(1, 100), ChronoUnit.MINUTES)) + .executionStartTime(Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(100, ChronoUnit.MINUTES)) + .executionEndTime(executionEndTime) + .isLatest(true) + .error(randomAlphaOfLength(5)) + .checkpointId(randomAlphaOfLength(5)) + .lastUpdateTime(Instant.now().truncatedTo(ChronoUnit.SECONDS)) + .startedBy(randomAlphaOfLength(5)) + .stoppedBy(stoppedBy) + .entity(entity) + .build(); + return task; + } + + public static ADTask randomAdTask(String taskId, ADTaskState state, Instant executionEndTime, String stoppedBy, boolean withDetector) + throws IOException { + return randomAdTask(taskId, state, executionEndTime, stoppedBy, withDetector, ADTaskType.HISTORICAL_SINGLE_ENTITY); + } + + public static ADTask randomAdTask( + String taskId, + ADTaskState state, + Instant executionEndTime, + String stoppedBy, + boolean withDetector, + ADTaskType adTaskType + ) throws IOException { + AnomalyDetector detector = withDetector + ? randomAnomalyDetector(ImmutableMap.of(), Instant.now().truncatedTo(ChronoUnit.SECONDS), true) + : null; + Entity entity = null; + if (withDetector && adTaskType.name().startsWith("HISTORICAL_HC")) { + String categoryField = randomAlphaOfLength(5); + detector = TestHelpers + .randomDetector( + detector.getFeatureAttributes(), + detector.getIndices().get(0), + randomIntBetween(1, 10), + MockSimpleLog.TIME_FIELD, + ImmutableList.of(categoryField) + ); + if (adTaskType.name().equals(ADTaskType.HISTORICAL_HC_ENTITY.name())) { + entity = Entity.createSingleAttributeEntity(categoryField, randomAlphaOfLength(5)); + } + + } + + executionEndTime = executionEndTime == null ? 
null : executionEndTime.truncatedTo(ChronoUnit.SECONDS); + ADTask task = ADTask + .builder() + .taskId(taskId) + .taskType(adTaskType.name()) + .detectorId(randomAlphaOfLength(5)) + .detector(detector) + .entity(entity) + .state(state.name()) + .taskProgress(0.5f) + .initProgress(1.0f) + .currentPiece(Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(randomIntBetween(1, 100), ChronoUnit.MINUTES)) + .executionStartTime(Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(100, ChronoUnit.MINUTES)) + .executionEndTime(executionEndTime) + .isLatest(true) + .error(randomAlphaOfLength(5)) + .checkpointId(randomAlphaOfLength(5)) + .lastUpdateTime(Instant.now().truncatedTo(ChronoUnit.SECONDS)) + .startedBy(randomAlphaOfLength(5)) + .stoppedBy(stoppedBy) + .lastUpdateTime(Instant.now().truncatedTo(ChronoUnit.SECONDS)) + .build(); + return task; + } + + public static ADTask randomAdTask( + String taskId, + ADTaskState state, + Instant executionEndTime, + String stoppedBy, + AnomalyDetector detector + ) { + executionEndTime = executionEndTime == null ? null : executionEndTime.truncatedTo(ChronoUnit.SECONDS); + Entity entity = null; + if (detector != null) { + if (detector.hasMultipleCategories()) { + Map attrMap = new HashMap<>(); + detector.getCategoryFields().stream().forEach(f -> attrMap.put(f, randomAlphaOfLength(5))); + entity = Entity.createEntityByReordering(attrMap); + } else if (detector.isHighCardinality()) { + entity = Entity.createEntityByReordering(ImmutableMap.of(detector.getCategoryFields().get(0), randomAlphaOfLength(5))); + } + } + String taskType = entity == null ? ADTaskType.HISTORICAL_SINGLE_ENTITY.name() : ADTaskType.HISTORICAL_HC_ENTITY.name(); + ADTask task = ADTask + .builder() + .taskId(taskId) + .taskType(taskType) + .detectorId(randomAlphaOfLength(5)) + .detector(detector) + .state(state.name()) + .taskProgress(0.5f) + .initProgress(1.0f) + .currentPiece(Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(randomIntBetween(1, 100), ChronoUnit.MINUTES)) + .executionStartTime(Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(100, ChronoUnit.MINUTES)) + .executionEndTime(executionEndTime) + .isLatest(true) + .error(randomAlphaOfLength(5)) + .checkpointId(randomAlphaOfLength(5)) + .lastUpdateTime(Instant.now().truncatedTo(ChronoUnit.SECONDS)) + .startedBy(randomAlphaOfLength(5)) + .stoppedBy(stoppedBy) + .lastUpdateTime(Instant.now().truncatedTo(ChronoUnit.SECONDS)) + .entity(entity) + .build(); + return task; + } + + public static HttpEntity toHttpEntity(ToXContentObject object) throws IOException { + return new StringEntity(toJsonString(object), APPLICATION_JSON); + } + + public static HttpEntity toHttpEntity(String jsonString) throws IOException { + return new StringEntity(jsonString, APPLICATION_JSON); + } + + public static String toJsonString(ToXContentObject object) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + return TestHelpers.xContentBuilderToString(object.toXContent(builder, ToXContent.EMPTY_PARAMS)); + } + + public static RestStatus restStatus(Response response) { + return RestStatus.fromCode(response.getStatusLine().getStatusCode()); + } + + public static SearchHits createSearchHits(int totalHits) { + List hitList = new ArrayList<>(); + IntStream.range(0, totalHits).forEach(i -> hitList.add(new SearchHit(i))); + SearchHit[] hitArray = new SearchHit[hitList.size()]; + return new SearchHits(hitList.toArray(hitArray), new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), 1.0F); + } + + public static DiscoveryNode 
randomDiscoveryNode() { + return new DiscoveryNode(UUIDs.randomBase64UUID(), buildNewFakeTransportAddress(), Version.CURRENT); + } + + public static SearchRequest matchAllRequest() { + BoolQueryBuilder query = new BoolQueryBuilder().filter(new MatchAllQueryBuilder()); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); + return new SearchRequest().source(searchSourceBuilder); + } + + public static Map parseStatsResult(String statsResult) throws IOException { + XContentParser parser = TestHelpers.parser(statsResult); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + Map adStats = new HashMap<>(); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + if (fieldName.equals("nodes")) { + Map nodesAdStats = new HashMap<>(); + adStats.put("nodes", nodesAdStats); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String nodeId = parser.currentName(); + Map nodeAdStats = new HashMap<>(); + nodesAdStats.put(nodeId, nodeAdStats); + parser.nextToken(); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String nodeStatName = parser.currentName(); + XContentParser.Token token = parser.nextToken(); + if (nodeStatName.equals("models")) { + parser.skipChildren(); + } else if (nodeStatName.contains("_count")) { + nodeAdStats.put(nodeStatName, parser.longValue()); + } else { + nodeAdStats.put(nodeStatName, parser.text()); + } + } + } + } else if (fieldName.contains("_count")) { + adStats.put(fieldName, parser.longValue()); + } else { + adStats.put(fieldName, parser.text()); + } + } + return adStats; + } + + public static DetectorValidationIssue randomDetectorValidationIssue() { + DetectorValidationIssue issue = new DetectorValidationIssue( + ValidationAspect.DETECTOR, + ValidationIssueType.NAME, + randomAlphaOfLength(5) + ); + return issue; + } + + public static DetectorValidationIssue randomDetectorValidationIssueWithSubIssues(Map subIssues) { + DetectorValidationIssue issue = new DetectorValidationIssue( + ValidationAspect.DETECTOR, + ValidationIssueType.NAME, + randomAlphaOfLength(5), + subIssues, + null + ); + return issue; + } + + public static DetectorValidationIssue randomDetectorValidationIssueWithDetectorIntervalRec(long intervalRec) { + DetectorValidationIssue issue = new DetectorValidationIssue( + ValidationAspect.MODEL, + ValidationIssueType.DETECTION_INTERVAL, + ADCommonMessages.DETECTOR_INTERVAL_REC + intervalRec, + null, + new IntervalTimeConfiguration(intervalRec, ChronoUnit.MINUTES) + ); + return issue; + } + + public static ClusterState createClusterState() { + final Map mappings = new HashMap<>(); + + mappings + .put( + CommonName.JOB_INDEX, + IndexMetadata + .builder("test") + .settings( + Settings + .builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id) + ) + .build() + ); + + // The usage of Collections.unmodifiableMap is due to replacing ImmutableOpenMap + // with java.util.Map in the core (refer to https://tinyurl.com/5fjdccs3 and https://tinyurl.com/5fjdccs3) + // The meaning and logic of the code stay the same. 
+ Metadata metaData = Metadata.builder().indices(Collections.unmodifiableMap(mappings)).build(); + ClusterState clusterState = new ClusterState( + new ClusterName("test_name"), + 1l, + "uuid", + metaData, + null, + null, + null, + new HashMap<>(), + 1, + true + ); + return clusterState; + } + + public static ImputationOption randomImputationOption() { + double[] defaultFill = DoubleStream.generate(OpenSearchTestCase::randomDouble).limit(10).toArray(); + ImputationOption fixedValue = new ImputationOption(ImputationMethod.FIXED_VALUES, Optional.of(defaultFill), false); + ImputationOption linear = new ImputationOption(ImputationMethod.LINEAR, Optional.of(defaultFill), false); + ImputationOption linearIntSensitive = new ImputationOption(ImputationMethod.LINEAR, Optional.of(defaultFill), true); + ImputationOption zero = new ImputationOption(ImputationMethod.ZERO); + ImputationOption previous = new ImputationOption(ImputationMethod.PREVIOUS); + + List options = List.of(fixedValue, linear, linearIntSensitive, zero, previous); + + // Select a random option + int randomIndex = Randomness.get().nextInt(options.size()); + return options.get(randomIndex); + } + + public static class ForecasterBuilder { + String forecasterId; + Long version; + String name; + String description; + String timeField; + List indices; + List features; + QueryBuilder filterQuery; + TimeConfiguration forecastInterval; + TimeConfiguration windowDelay; + Integer shingleSize; + Map uiMetadata; + Integer schemaVersion; + Instant lastUpdateTime; + List categoryFields; + User user; + String resultIndex; + Integer horizon; + ImputationOption imputationOption; + + ForecasterBuilder() throws IOException { + forecasterId = randomAlphaOfLength(10); + version = randomLong(); + name = randomAlphaOfLength(10); + description = randomAlphaOfLength(20); + timeField = randomAlphaOfLength(5); + indices = ImmutableList.of(randomAlphaOfLength(10)); + features = ImmutableList.of(randomFeature()); + filterQuery = randomQuery(); + forecastInterval = randomIntervalTimeConfiguration(); + windowDelay = randomIntervalTimeConfiguration(); + shingleSize = randomIntBetween(1, 20); + uiMetadata = ImmutableMap.of(randomAlphaOfLength(5), randomAlphaOfLength(10)); + schemaVersion = randomInt(); + lastUpdateTime = Instant.now().truncatedTo(ChronoUnit.SECONDS); + categoryFields = ImmutableList.of(randomAlphaOfLength(5)); + user = randomUser(); + resultIndex = null; + horizon = randomIntBetween(1, 20); + imputationOption = randomImputationOption(); + } + + public static ForecasterBuilder newInstance() throws IOException { + return new ForecasterBuilder(); + } + + public ForecasterBuilder setConfigId(String configId) { + this.forecasterId = configId; + return this; + } + + public ForecasterBuilder setVersion(Long version) { + this.version = version; + return this; + } + + public ForecasterBuilder setName(String name) { + this.name = name; + return this; + } + + public ForecasterBuilder setDescription(String description) { + this.description = description; + return this; + } + + public ForecasterBuilder setTimeField(String timeField) { + this.timeField = timeField; + return this; + } + + public ForecasterBuilder setIndices(List indices) { + this.indices = indices; + return this; + } + + public ForecasterBuilder setFeatureAttributes(List featureAttributes) { + this.features = featureAttributes; + return this; + } + + public ForecasterBuilder setFilterQuery(QueryBuilder filterQuery) { + this.filterQuery = filterQuery; + return this; + } + + public ForecasterBuilder 
setDetectionInterval(TimeConfiguration forecastInterval) { + this.forecastInterval = forecastInterval; + return this; + } + + public ForecasterBuilder setWindowDelay(TimeConfiguration windowDelay) { + this.windowDelay = windowDelay; + return this; + } + + public ForecasterBuilder setShingleSize(Integer shingleSize) { + this.shingleSize = shingleSize; + return this; + } + + public ForecasterBuilder setUiMetadata(Map uiMetadata) { + this.uiMetadata = uiMetadata; + return this; + } + + public ForecasterBuilder setSchemaVersion(Integer schemaVersion) { + this.schemaVersion = schemaVersion; + return this; + } + + public ForecasterBuilder setLastUpdateTime(Instant lastUpdateTime) { + this.lastUpdateTime = lastUpdateTime; + return this; + } + + public ForecasterBuilder setCategoryFields(List categoryFields) { + this.categoryFields = categoryFields; + return this; + } + + public ForecasterBuilder setUser(User user) { + this.user = user; + return this; + } + + public ForecasterBuilder setCustomResultIndex(String resultIndex) { + this.resultIndex = resultIndex; + return this; + } + + public ForecasterBuilder setNullImputationOption() { + this.imputationOption = null; + return this; + } + + public Forecaster build() { + return new Forecaster( + forecasterId, + version, + name, + description, + timeField, + indices, + features, + filterQuery, + forecastInterval, + windowDelay, + shingleSize, + uiMetadata, + schemaVersion, + lastUpdateTime, + categoryFields, + user, + resultIndex, + horizon, + imputationOption + ); + } + } + + public static Forecaster randomForecaster() throws IOException { + return new Forecaster( + randomAlphaOfLength(10), + randomLong(), + randomAlphaOfLength(10), + randomAlphaOfLength(20), + randomAlphaOfLength(5), + ImmutableList.of(randomAlphaOfLength(10)), + ImmutableList.of(randomFeature()), + randomQuery(), + randomIntervalTimeConfiguration(), + randomIntervalTimeConfiguration(), + randomIntBetween(1, 20), + ImmutableMap.of(randomAlphaOfLength(5), randomAlphaOfLength(10)), + randomInt(), + Instant.now().truncatedTo(ChronoUnit.SECONDS), + ImmutableList.of(randomAlphaOfLength(5)), + randomUser(), + null, + randomIntBetween(1, 20), + randomImputationOption() + ); + } +} diff --git a/src/test/java/org/opensearch/timeseries/TimeSeriesPluginTests.java b/src/test/java/org/opensearch/timeseries/TimeSeriesPluginTests.java new file mode 100644 index 000000000..ff170a1d5 --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/TimeSeriesPluginTests.java @@ -0,0 +1,87 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.List; + +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.opensearch.ad.ADUnitTestCase; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.env.Environment; + +import io.protostuff.LinkedBuffer; + +public class TimeSeriesPluginTests extends ADUnitTestCase { + TimeSeriesAnalyticsPlugin plugin; + + @Override + public void setUp() throws Exception { + super.setUp(); + plugin = new TimeSeriesAnalyticsPlugin(); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + plugin.close(); + } + + /** + * We have legacy settings. TimeSeriesAnalyticsPlugin's createComponents can trigger + * warnings when using these legacy settings. + */ + @Override + protected boolean enableWarningsCheck() { + return false; + } + + public void testDeserializeRCFBufferPool() throws Exception { + Settings.Builder settingsBuilder = Settings.builder(); + List> allSettings = plugin.getSettings(); + for (Setting setting : allSettings) { + Object defaultVal = setting.getDefault(Settings.EMPTY); + if (defaultVal instanceof Boolean) { + settingsBuilder.put(setting.getKey(), (Boolean) defaultVal); + } else { + settingsBuilder.put(setting.getKey(), defaultVal.toString()); + } + } + Settings settings = settingsBuilder.build(); + + Setting[] settingArray = new Setting[allSettings.size()]; + settingArray = allSettings.toArray(settingArray); + + ClusterSettings clusterSettings = clusterSetting(settings, settingArray); + ClusterService clusterService = new ClusterService(settings, clusterSettings, null); + + Environment environment = mock(Environment.class); + when(environment.settings()).thenReturn(settings); + plugin.createComponents(mock(Client.class), clusterService, null, null, null, null, environment, null, null, null, null); + GenericObjectPool deserializeRCFBufferPool = plugin.serializeRCFBufferPool; + deserializeRCFBufferPool.addObject(); + LinkedBuffer buffer = deserializeRCFBufferPool.borrowObject(); + assertTrue(null != buffer); + } + + public void testOverriddenJobTypeAndIndex() { + assertEquals("opensearch_time_series_analytics", plugin.getJobType()); + assertEquals(".opendistro-anomaly-detector-jobs", plugin.getJobIndex()); + } + +} diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorPluginTests.java b/src/test/java/org/opensearch/timeseries/TimeSeriesPluginTests.java-e similarity index 89% rename from src/test/java/org/opensearch/ad/AnomalyDetectorPluginTests.java rename to src/test/java/org/opensearch/timeseries/TimeSeriesPluginTests.java-e index e152e0b72..1ad8032b5 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorPluginTests.java +++ b/src/test/java/org/opensearch/timeseries/TimeSeriesPluginTests.java-e @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad; +package org.opensearch.timeseries; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -17,6 +17,7 @@ import java.util.List; import org.apache.commons.pool2.impl.GenericObjectPool; +import org.opensearch.ad.ADUnitTestCase; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -26,13 +27,13 @@ import io.protostuff.LinkedBuffer; -public class AnomalyDetectorPluginTests extends ADUnitTestCase { - AnomalyDetectorPlugin plugin; +public class TimeSeriesPluginTests extends ADUnitTestCase { + TimeSeriesAnalyticsPlugin plugin; @Override public void setUp() throws Exception { super.setUp(); - plugin = new AnomalyDetectorPlugin(); + plugin = new TimeSeriesAnalyticsPlugin(); } @Override @@ -42,7 +43,7 @@ public void tearDown() throws Exception { } /** - * We have legacy setting. AnomalyDetectorPlugin's createComponents can trigger + * We have legacy setting. TimeSeriesAnalyticsPlugin's createComponents can trigger * warning when using these legacy settings. */ @Override diff --git a/src/test/java/org/opensearch/timeseries/common/exception/ValidationExceptionTests.java-e b/src/test/java/org/opensearch/timeseries/common/exception/ValidationExceptionTests.java-e new file mode 100644 index 000000000..bfcd5ad7a --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/common/exception/ValidationExceptionTests.java-e @@ -0,0 +1,51 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.common.exception; + +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; + +public class ValidationExceptionTests extends OpenSearchTestCase { + public void testConstructorDetector() { + String message = randomAlphaOfLength(5); + ValidationException exception = new ValidationException(message, ValidationIssueType.NAME, ValidationAspect.DETECTOR); + assertEquals(ValidationIssueType.NAME, exception.getType()); + assertEquals(ValidationAspect.DETECTOR, exception.getAspect()); + } + + public void testConstructorModel() { + String message = randomAlphaOfLength(5); + ValidationException exception = new ValidationException(message, ValidationIssueType.CATEGORY, ValidationAspect.MODEL); + assertEquals(ValidationIssueType.CATEGORY, exception.getType()); + assertEquals(ValidationAspect.getName(CommonName.MODEL_ASPECT), exception.getAspect()); + } + + public void testToString() { + String message = randomAlphaOfLength(5); + ValidationException exception = new ValidationException(message, ValidationIssueType.NAME, ValidationAspect.DETECTOR); + String exceptionString = exception.toString(); + logger.info("exception string: " + exceptionString); + ValidationException exceptionNoType = new ValidationException(message, ValidationIssueType.NAME, null); + String exceptionStringNoType = exceptionNoType.toString(); + logger.info("exception string no type: " + exceptionStringNoType); + } + + public void testForecasterAspect() { + String message = randomAlphaOfLength(5); + ValidationException exception = new ValidationException(message, ValidationIssueType.CATEGORY, ValidationAspect.FORECASTER); + assertEquals(ValidationIssueType.CATEGORY, exception.getType()); + assertEquals(ValidationAspect.getName(ForecastCommonName.FORECASTER_ASPECT), exception.getAspect()); + } +} diff --git a/src/test/java/org/opensearch/timeseries/dataprocessor/FixedValueImputerTests.java-e b/src/test/java/org/opensearch/timeseries/dataprocessor/FixedValueImputerTests.java-e new file mode 100644 index 000000000..81b9b5bfb --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/dataprocessor/FixedValueImputerTests.java-e @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.dataprocessor; + +import static org.junit.Assert.assertArrayEquals; + +import org.junit.Test; + +public class FixedValueImputerTests { + + @Test + public void testImpute() { + // Initialize the FixedValueImputer with some fixed values + double[] fixedValues = { 2.0, 3.0 }; + FixedValueImputer imputer = new FixedValueImputer(fixedValues); + + // Create a sample array with some missing values (Double.NaN) + double[][] samples = { { 1.0, Double.NaN, 3.0 }, { Double.NaN, 2.0, 3.0 } }; + + // Call the impute method + double[][] imputed = imputer.impute(samples, 3); + + // Check the results + double[][] expected = { { 1.0, 2.0, 3.0 }, { 3.0, 2.0, 3.0 } }; + double delta = 0.0001; + + for (int i = 0; i < expected.length; i++) { + assertArrayEquals("The arrays are not equal", expected[i], imputed[i], delta); + } + } +} diff --git a/src/test/java/org/opensearch/timeseries/dataprocessor/ImputationOptionTests.java b/src/test/java/org/opensearch/timeseries/dataprocessor/ImputationOptionTests.java index f1cb8b36e..9adb57ed9 
100644 --- a/src/test/java/org/opensearch/timeseries/dataprocessor/ImputationOptionTests.java +++ b/src/test/java/org/opensearch/timeseries/dataprocessor/ImputationOptionTests.java @@ -8,10 +8,10 @@ import java.io.IOException; import java.util.Optional; -import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; diff --git a/src/test/java/org/opensearch/timeseries/dataprocessor/ImputationOptionTests.java-e b/src/test/java/org/opensearch/timeseries/dataprocessor/ImputationOptionTests.java-e new file mode 100644 index 000000000..758baac1d --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/dataprocessor/ImputationOptionTests.java-e @@ -0,0 +1,124 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.dataprocessor; + +import java.io.IOException; +import java.util.Optional; + +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.test.OpenSearchTestCase; + +public class ImputationOptionTests extends OpenSearchTestCase { + + public void testStreamInputAndOutput() throws IOException { + // Prepare the data to be read by the StreamInput object. + ImputationMethod method = ImputationMethod.PREVIOUS; + double[] defaultFill = { 1.0, 2.0, 3.0 }; + + ImputationOption option = new ImputationOption(method, Optional.of(defaultFill), false); + + // Write the ImputationOption to the StreamOutput. + BytesStreamOutput out = new BytesStreamOutput(); + option.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + + // Create an ImputationOption using the mocked StreamInput. + ImputationOption inOption = new ImputationOption(in); + + // Check that the created ImputationOption has the correct values. 
+ assertEquals(method, inOption.getMethod()); + assertArrayEquals(defaultFill, inOption.getDefaultFill().get(), 1e-6); + } + + public void testToXContent() throws IOException { + double[] defaultFill = { 1.0, 2.0, 3.0 }; + ImputationOption imputationOption = new ImputationOption(ImputationMethod.FIXED_VALUES, Optional.of(defaultFill), false); + + String xContent = "{" + "\"method\":\"FIXED_VALUES\"," + "\"defaultFill\":[1.0,2.0,3.0],\"integerSensitive\":false" + "}"; + + XContentBuilder builder = imputationOption.toXContent(JsonXContent.contentBuilder(), ToXContent.EMPTY_PARAMS); + String actualJson = BytesReference.bytes(builder).utf8ToString(); + + assertEquals(xContent, actualJson); + } + + public void testParse() throws IOException { + String xContent = "{" + "\"method\":\"FIXED_VALUES\"," + "\"defaultFill\":[1.0,2.0,3.0],\"integerSensitive\":false" + "}"; + + double[] defaultFill = { 1.0, 2.0, 3.0 }; + ImputationOption imputationOption = new ImputationOption(ImputationMethod.FIXED_VALUES, Optional.of(defaultFill), false); + + try ( + XContentParser parser = JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, xContent) + ) { + // advance to first token + XContentParser.Token token = parser.nextToken(); + if (token != XContentParser.Token.START_OBJECT) { + throw new IOException("Expected data to start with an Object"); + } + + ImputationOption parsedOption = ImputationOption.parse(parser); + + assertEquals(imputationOption.getMethod(), parsedOption.getMethod()); + assertTrue(imputationOption.getDefaultFill().isPresent()); + assertTrue(parsedOption.getDefaultFill().isPresent()); + assertEquals(imputationOption.getDefaultFill().get().length, parsedOption.getDefaultFill().get().length); + for (int i = 0; i < imputationOption.getDefaultFill().get().length; i++) { + assertEquals(imputationOption.getDefaultFill().get()[i], parsedOption.getDefaultFill().get()[i], 0); + } + } + } + + public void testEqualsAndHashCode() { + double[] defaultFill1 = { 1.0, 2.0, 3.0 }; + double[] defaultFill2 = { 4.0, 5.0, 6.0 }; + + ImputationOption option1 = new ImputationOption(ImputationMethod.FIXED_VALUES, Optional.of(defaultFill1), false); + ImputationOption option2 = new ImputationOption(ImputationMethod.FIXED_VALUES, Optional.of(defaultFill1), false); + ImputationOption option3 = new ImputationOption(ImputationMethod.LINEAR, Optional.of(defaultFill2), false); + + // Test reflexivity + assertTrue(option1.equals(option1)); + + // Test symmetry + assertTrue(option1.equals(option2)); + assertTrue(option2.equals(option1)); + + // Test transitivity + ImputationOption option2Clone = new ImputationOption(ImputationMethod.FIXED_VALUES, Optional.of(defaultFill1), false); + assertTrue(option1.equals(option2)); + assertTrue(option2.equals(option2Clone)); + assertTrue(option1.equals(option2Clone)); + + // Test consistency: multiple invocations of a.equals(b) consistently return true or consistently return false. 
+ assertTrue(option1.equals(option2)); + assertTrue(option1.equals(option2)); + + // Test non-nullity + assertFalse(option1.equals(null)); + + // Test hashCode consistency + assertEquals(option1.hashCode(), option1.hashCode()); + + // Test hashCode equality + assertTrue(option1.equals(option2)); + assertEquals(option1.hashCode(), option2.hashCode()); + + // Test inequality + assertFalse(option1.equals(option3)); + assertNotEquals(option1.hashCode(), option3.hashCode()); + } +} diff --git a/src/test/java/org/opensearch/timeseries/dataprocessor/IntegerSensitiveLinearUniformImputerTests.java-e b/src/test/java/org/opensearch/timeseries/dataprocessor/IntegerSensitiveLinearUniformImputerTests.java-e new file mode 100644 index 000000000..03e8b6cb3 --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/dataprocessor/IntegerSensitiveLinearUniformImputerTests.java-e @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.dataprocessor; + +import static org.junit.Assert.assertArrayEquals; + +import java.util.Arrays; +import java.util.Collection; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +/** + * Compared to MultiFeatureLinearUniformImputerTests, outputs are different and + * integerSensitive is enabled + * + */ +@RunWith(Parameterized.class) +public class IntegerSensitiveLinearUniformImputerTests { + + @Parameters + public static Collection data() { + double[][] singleComponent = { { -1.0, 2.0 }, { 1.0, 1.0 } }; + double[][] multiComponent = { { 0.0, 1.0, -1.0 }, { 1.0, 1.0, 1.0 } }; + + return Arrays + .asList( + new Object[][] { + // after integer sensitive rint rounding + { singleComponent, 2, singleComponent }, + { singleComponent, 3, new double[][] { { -1.0, 0, 2.0 }, { 1.0, 1.0, 1.0 } } }, + { singleComponent, 4, new double[][] { { -1.0, 0.0, 1.0, 2.0 }, { 1.0, 1.0, 1.0, 1.0 } } }, + { multiComponent, 3, multiComponent }, + { multiComponent, 4, new double[][] { { 0.0, 1.0, 0.0, -1.0 }, { 1.0, 1.0, 1.0, 1.0 } } }, + { multiComponent, 5, new double[][] { { 0.0, 0.0, 1.0, 0.0, -1.0 }, { 1.0, 1.0, 1.0, 1.0, 1.0 } } }, + { multiComponent, 6, new double[][] { { 0.0, 0.0, 1.0, 1.0, -0.0, -1.0 }, { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 } } }, } + ); + } + + private double[][] input; + private int numInterpolants; + private double[][] expected; + private Imputer imputer; + + public IntegerSensitiveLinearUniformImputerTests(double[][] input, int numInterpolants, double[][] expected) { + this.input = input; + this.numInterpolants = numInterpolants; + this.expected = expected; + } + + @Before + public void setUp() { + this.imputer = new LinearUniformImputer(true); + } + + @Test + public void testImputation() { + double[][] actual = imputer.impute(input, numInterpolants); + double delta = 1e-8; + int numFeatures = expected.length; + + for (int i = 0; i < numFeatures; i++) { + assertArrayEquals(expected[i], actual[i], delta); + } + } +} diff --git a/src/test/java/org/opensearch/timeseries/dataprocessor/MultiFeatureLinearUniformImputerTests.java-e b/src/test/java/org/opensearch/timeseries/dataprocessor/MultiFeatureLinearUniformImputerTests.java-e new file mode 100644 index 000000000..3656be278 --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/dataprocessor/MultiFeatureLinearUniformImputerTests.java-e @@ -0,0 +1,74 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch 
Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.dataprocessor; + +import static org.junit.Assert.assertArrayEquals; + +import java.util.Arrays; +import java.util.Collection; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public class MultiFeatureLinearUniformImputerTests { + + @Parameters + public static Collection data() { + double[][] singleComponent = { { -1.0, 2.0 }, { 1.0, 1.0 } }; + double[][] multiComponent = { { 0.0, 1.0, -1.0 }, { 1.0, 1.0, 1.0 } }; + double oneThird = 1.0 / 3.0; + + return Arrays + .asList( + new Object[][] { + // no integer sensitive rint rounding at the end of singleFeatureImpute. + { singleComponent, 2, singleComponent }, + { singleComponent, 3, new double[][] { { -1.0, 0.5, 2.0 }, { 1.0, 1.0, 1.0 } } }, + { singleComponent, 4, new double[][] { { -1.0, 0.0, 1.0, 2.0 }, { 1.0, 1.0, 1.0, 1.0 } } }, + { multiComponent, 3, multiComponent }, + { multiComponent, 4, new double[][] { { 0.0, 2 * oneThird, oneThird, -1.0 }, { 1.0, 1.0, 1.0, 1.0 } } }, + { multiComponent, 5, new double[][] { { 0.0, 0.5, 1.0, 0.0, -1.0 }, { 1.0, 1.0, 1.0, 1.0, 1.0 } } }, + { multiComponent, 6, new double[][] { { 0.0, 0.4, 0.8, 0.6, -0.2, -1.0 }, { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 } } }, } + ); + } + + private double[][] input; + private int numInterpolants; + private double[][] expected; + private Imputer imputer; + + public MultiFeatureLinearUniformImputerTests(double[][] input, int numInterpolants, double[][] expected) { + this.input = input; + this.numInterpolants = numInterpolants; + this.expected = expected; + } + + @Before + public void setUp() { + this.imputer = new LinearUniformImputer(false); + } + + @Test + public void testImputation() { + double[][] actual = imputer.impute(input, numInterpolants); + double delta = 1e-8; + int numFeatures = expected.length; + + for (int i = 0; i < numFeatures; i++) { + assertArrayEquals(expected[i], actual[i], delta); + } + } +} diff --git a/src/test/java/org/opensearch/timeseries/dataprocessor/PreviousValueImputerTests.java-e b/src/test/java/org/opensearch/timeseries/dataprocessor/PreviousValueImputerTests.java-e new file mode 100644 index 000000000..ff0cfbba2 --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/dataprocessor/PreviousValueImputerTests.java-e @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.dataprocessor; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +public class PreviousValueImputerTests { + @Test + void testSingleFeatureImpute() { + PreviousValueImputer imputer = new PreviousValueImputer(); + + double[] samples = { 1.0, Double.NaN, 3.0, Double.NaN, 5.0 }; + double[] expected = { 1.0, 1.0, 3.0, 3.0, 5.0 }; + + assertArrayEquals(expected, imputer.singleFeatureImpute(samples, 0), "Imputation failed"); + + // The second test checks whether the method removes leading Double.NaN values from the array + samples = new double[] { Double.NaN, 2.0, Double.NaN, 4.0 }; + expected = new double[] { Double.NaN, 2.0, 2.0, 4.0 }; + + assertArrayEquals(expected, imputer.singleFeatureImpute(samples, 0), "Imputation failed 
with leading NaN"); + } +} diff --git a/src/test/java/org/opensearch/timeseries/dataprocessor/SingleFeatureLinearUniformImputerTests.java-e b/src/test/java/org/opensearch/timeseries/dataprocessor/SingleFeatureLinearUniformImputerTests.java-e new file mode 100644 index 000000000..17aae0422 --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/dataprocessor/SingleFeatureLinearUniformImputerTests.java-e @@ -0,0 +1,50 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.dataprocessor; + +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; + +import junitparams.JUnitParamsRunner; +import junitparams.Parameters; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(JUnitParamsRunner.class) +public class SingleFeatureLinearUniformImputerTests { + + private Imputer imputer; + + @Before + public void setup() { + imputer = new LinearUniformImputer(false); + } + + private Object[] imputeData() { + return new Object[] { + new Object[] { new double[] { 25.25, 25.75 }, 3, new double[] { 25.25, 25.5, 25.75 } }, + new Object[] { new double[] { 25, 75 }, 3, new double[] { 25, 50, 75 } }, + new Object[] { new double[] { 25, 75.5 }, 3, new double[] { 25, 50.25, 75.5 } }, + new Object[] { new double[] { 25.25, 25.75 }, 3, new double[] { 25.25, 25.5, 25.75 } }, + new Object[] { new double[] { 25, 75 }, 3, new double[] { 25, 50, 75 } }, + new Object[] { new double[] { 25, 75.5 }, 3, new double[] { 25, 50.25, 75.5 } } }; + } + + @Test + @Parameters(method = "imputeData") + public void impute_returnExpected(double[] samples, int num, double[] expected) { + assertTrue(Arrays.equals(expected, imputer.singleFeatureImpute(samples, num))); + } +} diff --git a/src/test/java/org/opensearch/timeseries/dataprocessor/ZeroImputerTests.java-e b/src/test/java/org/opensearch/timeseries/dataprocessor/ZeroImputerTests.java-e new file mode 100644 index 000000000..a13189db2 --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/dataprocessor/ZeroImputerTests.java-e @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.dataprocessor; + +import static org.junit.Assert.assertArrayEquals; + +import junitparams.JUnitParamsRunner; +import junitparams.Parameters; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(JUnitParamsRunner.class) +public class ZeroImputerTests { + + private Imputer imputer; + + @Before + public void setup() { + imputer = new ZeroImputer(); + } + + private Object[] imputeData() { + return new Object[] { + new Object[] { new double[] { 25.25, Double.NaN, 25.75 }, 3, new double[] { 25.25, 0, 25.75 } }, + new Object[] { new double[] { Double.NaN, 25, 75 }, 3, new double[] { 0, 25, 75 } }, + new Object[] { new double[] { 25, 75.5, Double.NaN }, 3, new double[] { 25, 75.5, 0 } }, }; + } + + @Test + @Parameters(method = "imputeData") + public void impute_returnExpected(double[] samples, int num, double[] expected) { + assertArrayEquals("The arrays are not equal", expected, imputer.singleFeatureImpute(samples, num), 0.001); + } +} diff --git a/src/test/java/org/opensearch/timeseries/indices/IndexManagementIntegTestCase.java 
b/src/test/java/org/opensearch/timeseries/indices/IndexManagementIntegTestCase.java index dd1ec06da..0ef089f3c 100644 --- a/src/test/java/org/opensearch/timeseries/indices/IndexManagementIntegTestCase.java +++ b/src/test/java/org/opensearch/timeseries/indices/IndexManagementIntegTestCase.java @@ -20,10 +20,10 @@ import org.mockito.ArgumentCaptor; import org.opensearch.action.ActionListener; -import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.constant.CommonMessages; diff --git a/src/test/java/org/opensearch/timeseries/indices/IndexManagementIntegTestCase.java-e b/src/test/java/org/opensearch/timeseries/indices/IndexManagementIntegTestCase.java-e new file mode 100644 index 000000000..f930fab22 --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/indices/IndexManagementIntegTestCase.java-e @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.indices; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.mockito.ArgumentCaptor; +import org.opensearch.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.function.ExecutorFunction; + +public abstract class IndexManagementIntegTestCase & TimeSeriesIndex, ISMType extends IndexManagement> + extends OpenSearchIntegTestCase { + + public void validateCustomIndexForBackendJob(ISMType indices, String resultMapping) throws IOException, InterruptedException { + + Map asMap = XContentHelper.convertToMap(new BytesArray(resultMapping), false, XContentType.JSON).v2(); + String resultIndex = "test_index"; + + client() + .admin() + .indices() + .prepareCreate(resultIndex) + .setSettings(Settings.builder().put("index.number_of_shards", 1).put("index.number_of_replicas", 0)) + .setMapping(asMap) + .get(); + ensureGreen(resultIndex); + + String securityLogId = "logId"; + String user = "testUser"; + List roles = Arrays.asList("role1", "role2"); + ExecutorFunction function = mock(ExecutorFunction.class); + ActionListener listener = mock(ActionListener.class); + + CountDownLatch latch = new CountDownLatch(1); + doAnswer(invocation -> { + latch.countDown(); + return null; + }).when(function).execute(); + latch.await(20, TimeUnit.SECONDS); + indices.validateCustomIndexForBackendJob(resultIndex, securityLogId, user, roles, function, listener); + verify(listener, never()).onFailure(any(Exception.class)); + } + + public void validateCustomIndexForBackendJobInvalidMapping(ISMType 
indices) { + String resultIndex = "test_index"; + + client() + .admin() + .indices() + .prepareCreate(resultIndex) + .setSettings(Settings.builder().put("index.number_of_shards", 1).put("index.number_of_replicas", 0)) + .setMapping("ip", "type=ip") + .get(); + ensureGreen(resultIndex); + + String securityLogId = "logId"; + String user = "testUser"; + List roles = Arrays.asList("role1", "role2"); + ExecutorFunction function = mock(ExecutorFunction.class); + ActionListener listener = mock(ActionListener.class); + + indices.validateCustomIndexForBackendJob(resultIndex, securityLogId, user, roles, function, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(EndRunException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals("Result index mapping is not correct", exceptionCaptor.getValue().getMessage()); + } + + public void validateCustomIndexForBackendJobNoIndex(ISMType indices) { + String resultIndex = "testIndex"; + String securityLogId = "logId"; + String user = "testUser"; + List roles = Arrays.asList("role1", "role2"); + ExecutorFunction function = mock(ExecutorFunction.class); + ActionListener listener = mock(ActionListener.class); + + indices.validateCustomIndexForBackendJob(resultIndex, securityLogId, user, roles, function, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(EndRunException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals(CommonMessages.CAN_NOT_FIND_RESULT_INDEX + resultIndex, exceptionCaptor.getValue().getMessage()); + } +} diff --git a/src/test/java/org/opensearch/timeseries/util/LTrimTests.java-e b/src/test/java/org/opensearch/timeseries/util/LTrimTests.java-e new file mode 100644 index 000000000..384982828 --- /dev/null +++ b/src/test/java/org/opensearch/timeseries/util/LTrimTests.java-e @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.util; + +import org.opensearch.test.OpenSearchTestCase; + +public class LTrimTests extends OpenSearchTestCase { + + public void testLtrimEmptyArray() { + + double[][] input = {}; + double[][] expectedOutput = {}; + + assertArrayEquals(expectedOutput, DataUtil.ltrim(input)); + } + + public void testLtrimAllNaN() { + + double[][] input = { { Double.NaN, Double.NaN }, { Double.NaN, Double.NaN }, { Double.NaN, Double.NaN } }; + double[][] expectedOutput = {}; + + assertArrayEquals(expectedOutput, DataUtil.ltrim(input)); + } + + public void testLtrimSomeNaN() { + + double[][] input = { { Double.NaN, Double.NaN }, { 1.0, 2.0 }, { 3.0, 4.0 } }; + double[][] expectedOutput = { { 1.0, 2.0 }, { 3.0, 4.0 } }; + + assertArrayEquals(expectedOutput, DataUtil.ltrim(input)); + } + + public void testLtrimNoNaN() { + + double[][] input = { { 1.0, 2.0 }, { 3.0, 4.0 }, { 5.0, 6.0 } }; + double[][] expectedOutput = { { 1.0, 2.0 }, { 3.0, 4.0 }, { 5.0, 6.0 } }; + + assertArrayEquals(expectedOutput, DataUtil.ltrim(input)); + } +} diff --git a/src/test/java/test/org/opensearch/ad/util/ClusterCreation.java-e b/src/test/java/test/org/opensearch/ad/util/ClusterCreation.java-e new file mode 100644 index 000000000..3eb4fa80a --- /dev/null +++ b/src/test/java/test/org/opensearch/ad/util/ClusterCreation.java-e @@ -0,0 +1,109 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package test.org.opensearch.ad.util; + +import static org.mockito.Mockito.mock; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; +import static org.opensearch.cluster.node.DiscoveryNodeRole.DATA_ROLE; + +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.Version; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.common.transport.TransportAddress; + +public class ClusterCreation { + /** + * Creates a cluster state where local node and clusterManager node can be specified + * + * @param localNode node in allNodes that is the local node + * @param clusterManagerNode node in allNodes that is the clusterManager node. Can be null if no clusterManager exists + * @param allNodes all nodes in the cluster + * @return cluster state + */ + public static ClusterState state( + ClusterName name, + DiscoveryNode localNode, + DiscoveryNode clusterManagerNode, + List allNodes + ) { + DiscoveryNodes.Builder discoBuilder = DiscoveryNodes.builder(); + for (DiscoveryNode node : allNodes) { + discoBuilder.add(node); + } + if (clusterManagerNode != null) { + discoBuilder.clusterManagerNodeId(clusterManagerNode.getId()); + } + discoBuilder.localNodeId(localNode.getId()); + + ClusterState.Builder state = ClusterState.builder(name); + state.nodes(discoBuilder); + state.metadata(Metadata.builder().generateClusterUuidIfNeeded()); + return state.build(); + } + + /** + * Create data node map + * @param numDataNodes the number of data nodes + * @return data nodes map + * + * TODO: ModelManagerTests has the same method. Refactor. 
+ */ + public static Map createDataNodes(int numDataNodes) { + Map dataNodes = new HashMap<>(); + for (int i = 0; i < numDataNodes; i++) { + dataNodes.put("foo" + i, mock(DiscoveryNode.class)); + } + return Collections.unmodifiableMap(dataNodes); + } + + /** + * Create a cluster state with 1 clusterManager node and a few data nodes + * @param numDataNodes the number of data nodes + * @return the cluster state + */ + public static ClusterState state(int numDataNodes) { + DiscoveryNode clusterManagerNode = new DiscoveryNode( + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + List allNodes = new ArrayList<>(); + allNodes.add(clusterManagerNode); + for (int i = 1; i <= numDataNodes - 1; i++) { + allNodes + .add( + new DiscoveryNode( + "foo" + i, + "foo" + i, + new TransportAddress(InetAddress.getLoopbackAddress(), 9300 + i), + Collections.emptyMap(), + Collections.singleton(DATA_ROLE), + Version.CURRENT + ) + ); + } + return state(new ClusterName("test"), clusterManagerNode, clusterManagerNode, allNodes); + } +} diff --git a/src/test/java/test/org/opensearch/ad/util/FakeNode.java b/src/test/java/test/org/opensearch/ad/util/FakeNode.java index 1af160f91..6cacecc95 100644 --- a/src/test/java/test/org/opensearch/ad/util/FakeNode.java +++ b/src/test/java/test/org/opensearch/ad/util/FakeNode.java @@ -36,7 +36,6 @@ import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.lease.Releasable; import org.opensearch.common.network.NetworkService; import org.opensearch.common.settings.ClusterSettings; @@ -45,6 +44,7 @@ import org.opensearch.common.transport.BoundTransportAddress; import org.opensearch.common.transport.TransportAddress; import org.opensearch.common.util.PageCacheRecycler; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.indices.breaker.NoneCircuitBreakerService; import org.opensearch.tasks.TaskManager; import org.opensearch.tasks.TaskResourceTrackingService; diff --git a/src/test/java/test/org/opensearch/ad/util/FakeNode.java-e b/src/test/java/test/org/opensearch/ad/util/FakeNode.java-e new file mode 100644 index 000000000..220ead1e5 --- /dev/null +++ b/src/test/java/test/org/opensearch/ad/util/FakeNode.java-e @@ -0,0 +1,176 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package test.org.opensearch.ad.util; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static org.opensearch.test.ClusterServiceUtils.createClusterService; +import static org.opensearch.test.ClusterServiceUtils.setState; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.core.Logger; +import org.apache.lucene.util.SetOnce; +import org.opensearch.Version; +import org.opensearch.action.admin.cluster.node.tasks.cancel.TransportCancelTasksAction; +import org.opensearch.action.admin.cluster.node.tasks.list.TransportListTasksAction; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.cluster.ClusterModule; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.network.NetworkService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.BoundTransportAddress; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.common.util.PageCacheRecycler; +import org.opensearch.indices.breaker.NoneCircuitBreakerService; +import org.opensearch.tasks.TaskManager; +import org.opensearch.tasks.TaskResourceTrackingService; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.tasks.MockTaskManager; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportInterceptor; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.nio.MockNioTransport; + +public class FakeNode implements Releasable { + protected static final Logger LOG = (Logger) LogManager.getLogger(FakeNode.class); + + public FakeNode( + String name, + ThreadPool threadPool, + final Settings nodeSettings, + final Set> settingsSet, + TransportInterceptor transportInterceptor, + Version version + ) { + final Function boundTransportAddressDiscoveryNodeFunction = address -> { + discoveryNode.set(new DiscoveryNode(name, address.publishAddress(), emptyMap(), emptySet(), Version.CURRENT)); + return discoveryNode.get(); + }; + transportService = new TransportService( + Settings.EMPTY, + new MockNioTransport( + Settings.EMPTY, + Version.V_2_1_0, + threadPool, + new NetworkService(Collections.emptyList()), + PageCacheRecycler.NON_RECYCLING_INSTANCE, + new NamedWriteableRegistry(ClusterModule.getNamedWriteables()), + new NoneCircuitBreakerService() + ) { + @Override + public TransportAddress[] addressesFromString(String address) { + return new TransportAddress[] { dns.getOrDefault(address, OpenSearchTestCase.buildNewFakeTransportAddress()) }; + } + }, + threadPool, + transportInterceptor, + boundTransportAddressDiscoveryNodeFunction, + null, + Collections.emptySet() + ) { + @Override + protected TaskManager createTaskManager( + Settings settings, + ClusterSettings clusterSettings, + ThreadPool threadPool, + Set taskHeaders + ) { + if (MockTaskManager.USE_MOCK_TASK_MANAGER_SETTING.get(settings)) { + return new 
MockTaskManager(settings, threadPool, taskHeaders); + } else { + return super.createTaskManager(settings, clusterSettings, threadPool, taskHeaders); + } + } + }; + + transportService.start(); + Set> internalSettings = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + internalSettings.addAll(settingsSet); + ClusterSettings clusterSettings = new ClusterSettings(nodeSettings, internalSettings); + clusterService = createClusterService(threadPool, discoveryNode.get(), clusterSettings); + clusterService.addStateApplier(transportService.getTaskManager()); + ActionFilters actionFilters = new ActionFilters(emptySet()); + TaskResourceTrackingService taskResourceTrackingService = new TaskResourceTrackingService( + Settings.EMPTY, + clusterService.getClusterSettings(), + threadPool + ); + transportListTasksAction = new TransportListTasksAction( + clusterService, + transportService, + actionFilters, + taskResourceTrackingService + ); + transportCancelTasksAction = new TransportCancelTasksAction(clusterService, transportService, actionFilters); + transportService.acceptIncomingRequests(); + } + + public FakeNode(String name, ThreadPool threadPool, Set> settings) { + this(name, threadPool, Settings.EMPTY, settings, TransportService.NOOP_TRANSPORT_INTERCEPTOR, Version.CURRENT); + } + + public final ClusterService clusterService; + public final TransportService transportService; + private final SetOnce discoveryNode = new SetOnce<>(); + public final TransportListTasksAction transportListTasksAction; + public final TransportCancelTasksAction transportCancelTasksAction; + private final Map dns = new ConcurrentHashMap<>(); + + @Override + public void close() { + clusterService.close(); + transportService.close(); + } + + public String getNodeId() { + return discoveryNode().getId(); + } + + public DiscoveryNode discoveryNode() { + return discoveryNode.get(); + } + + public static void connectNodes(FakeNode... nodes) { + List discoveryNodes = new ArrayList(nodes.length); + DiscoveryNode clusterManager = nodes[0].discoveryNode(); + for (int i = 0; i < nodes.length; i++) { + discoveryNodes.add(nodes[i].discoveryNode()); + } + + for (FakeNode node : nodes) { + setState( + node.clusterService, + ClusterCreation.state(new ClusterName("test"), node.discoveryNode(), clusterManager, discoveryNodes) + ); + } + for (FakeNode nodeA : nodes) { + for (FakeNode nodeB : nodes) { + nodeA.transportService.connectToNode(nodeB.discoveryNode()); + } + } + } +} diff --git a/src/test/java/test/org/opensearch/ad/util/JsonDeserializer.java-e b/src/test/java/test/org/opensearch/ad/util/JsonDeserializer.java-e new file mode 100644 index 000000000..0396fd3ba --- /dev/null +++ b/src/test/java/test/org/opensearch/ad/util/JsonDeserializer.java-e @@ -0,0 +1,459 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package test.org.opensearch.ad.util; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +import org.apache.commons.lang.StringUtils; +import org.opensearch.ad.common.exception.JsonPathNotFoundException; + +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParseException; +import com.google.gson.JsonParser; + +public class JsonDeserializer { + private static JsonParser parser = new JsonParser(); + + /** + * Search a Gson JsonObject inside a JSON string matching the input path + * expression + * + * @param jsonString + * an encoded JSON string + * @param paths + * path fragments + * @return the matching Gson JsonObject or null in case of no match. + * @throws IOException if the underlying input source has IO issues during parsing + */ + public static JsonElement getChildNode(String jsonString, String... paths) throws IOException { + JsonElement rootElement = parse(jsonString); + return getChildNode(rootElement, paths); + } + + /** + * JsonParseException is an unchecked exception. Rethrow a checked exception + * to force the client to catch that exception. + * @param jsonString json string to parse + * @return a parse tree of JsonElements corresponding to the specified JSON + * @throws IOException if the underlying input source has IO issues during parsing + */ + public static JsonElement parse(String jsonString) throws IOException { + try { + return parser.parse(jsonString); + } catch (JsonParseException e) { + throw new IOException(e.getCause()); + } + } + + /** + * Json validation: is the parameter a valid json string? + * + * @param jsonString an encoded JSON string + * @return whether this is a valid json string + */ + public static boolean isValidJson(String jsonString) { + try { + parser.parse(jsonString); + } catch (JsonParseException e) { + return false; + } + return true; + } + + /** + * Get the root node inside a JSON string + * + * @param jsonString + * an encoded JSON string + * @return the root node. + * @throws JsonPathNotFoundException if json path is invalid + * @throws IOException if the underlying input source has IO issues during parsing + */ + public static JsonObject getRootNode(String jsonString) throws JsonPathNotFoundException, IOException { + if (StringUtils.isBlank(jsonString)) + throw new JsonPathNotFoundException(); + + return parse(jsonString).getAsJsonObject(); + } + + /** + * Check if there is a Gson JsonObject inside a JSON string matching the + * input path expression. Only exact path match works (e.g. + * JSONUtils.hasChildNode("{\"qst\":3}", "qs") return false). + * + * @param jsonString + * an encoded JSON string + * @param paths + * path fragments + * @return true in case of match; otherwise false. + * @throws IOException if the underlying input source has IO issues during parsing + */ + public static boolean hasChildNode(String jsonString, String... paths) throws IOException { + if (StringUtils.isBlank(jsonString)) + return false; + return getChildNode(jsonString, paths) != null; + } + + /** + * Check if there is a Gson JsonObject inside a JSON string matching the + * input path expression. Only exact path match works (e.g. + * JSONUtils.hasChildNode("{\"qst\":3}", "qs") return false). + * + * @param jsonObject + * a Gson JsonObject + * @param paths + * path fragments + * @return true in case of match; otherwise false. + */ + public static boolean hasChildNode(JsonObject jsonObject, String... 
paths) { + if (jsonObject == null) + return false; + return getChildNode(jsonObject, paths) != null; + } + + /** + * Search a Gson JsonObject inside a Gson JsonObject matching the + * input path expression + * + * @param jsonElement + * a Gson JsonObject + * @param paths + * path fragments + * @return the matching Gson JsonObject or null in case of no match. + */ + public static JsonElement getChildNode(JsonElement jsonElement, String... paths) { + if (paths == null) { + return null; + } + for (int i = 0; i < paths.length; i++) { + String path = paths[i]; + if (!(jsonElement instanceof JsonObject)) { + return null; + } + JsonObject jsonObject = jsonElement.getAsJsonObject(); + if (!jsonObject.has(path)) + return null; + + jsonElement = jsonObject.get(path); + } + + return jsonElement; + } + + /** + * Search a string inside a JSON string matching the input path expression + * + * @param jsonElement + * a Gson JsonElement + * @param paths + * path fragments + * @return the matching string or null in case of no match. + */ + public static String getTextValue(JsonElement jsonElement, String... paths) throws JsonPathNotFoundException { + + jsonElement = getChildNode(jsonElement, paths); + if (jsonElement != null) { + return jsonElement.getAsString(); + } + + throw new JsonPathNotFoundException(); + } + + /** + * Search a long number inside a JSON string matching the input path + * expression + * + * @param jsonElement + * a Gson JsonElement + * @param paths + * path fragments + * @return the matching long number or null in case of no match. + */ + public static long getLongValue(JsonElement jsonElement, String... paths) throws JsonPathNotFoundException { + + jsonElement = getChildNode(jsonElement, paths); + if (jsonElement != null) { + return jsonElement.getAsLong(); + } + + throw new JsonPathNotFoundException(); + } + + /** + * Search an int number inside a JSON string matching the input path + * expression + * + * @param jsonElement + * a Gson JsonElement + * @param paths + * path fragments + * @return the matching int number or null in case of no match. + */ + public static int getIntValue(JsonElement jsonElement, String... paths) throws JsonPathNotFoundException { + + jsonElement = getChildNode(jsonElement, paths); + if (jsonElement != null) { + return jsonElement.getAsInt(); + } + + throw new JsonPathNotFoundException(); + } + + /** + * Search a string inside a JSON string matching the input path expression + * + * @param jsonString + * an encoded JSON string + * @param paths + * path fragments + * @return the matching string or null in case of no match. + * @throws JsonPathNotFoundException if json path is invalid + * @throws IOException if the underlying input source has IO issues during parsing + */ + public static String getTextValue(String jsonString, String... paths) throws JsonPathNotFoundException, IOException { + if (paths != null && paths.length > 0) { + JsonElement jsonElement = getChildNode(jsonString, paths); + if (jsonElement != null) { + return jsonElement.getAsString(); + } + } + + throw new JsonPathNotFoundException(); + } + + /** + * Search a long number inside a JSON string matching the input path + * expression + * + * @param jsonString + * an encoded JSON string + * @param paths + * path fragments + * @return the matching long number or null in case of no match. 
+ * @throws JsonPathNotFoundException if json path is invalid + * @throws IOException if the underlying input source has IO issues during parsing + */ + public static long getLongValue(String jsonString, String... paths) throws JsonPathNotFoundException, IOException { + if (paths != null && paths.length > 0) { + JsonElement jsonElement = getChildNode(jsonString, paths); + if (jsonElement != null) { + return jsonElement.getAsLong(); + } + } + + throw new JsonPathNotFoundException(); + } + + /** + * Search an int number inside a JSON string matching the input path + * expression + * + * @param jsonString + * an encoded JSON string + * @param paths + * path fragments + * @return the matching int number or null in case of no match. + * @throws JsonPathNotFoundException if json path is invalid + * @throws IOException if the underlying input source has IO issues during parsing + */ + public static int getIntValue(String jsonString, String... paths) throws JsonPathNotFoundException, IOException { + if (paths != null && paths.length > 0) { + JsonElement jsonElement = getChildNode(jsonString, paths); + if (jsonElement != null) { + return jsonElement.getAsInt(); + } + } + + throw new JsonPathNotFoundException(); + } + + private static String handleJsonPathNotFoundException(boolean returnEmptyStringIfMissing) throws JsonPathNotFoundException { + if (returnEmptyStringIfMissing) { + return ""; + } else { + throw new JsonPathNotFoundException(); + } + } + + /** + * Search JSON element from a JSON root. + * If returnEmptyStringIfMissing is true, return "" when json path not found. + * + * @param jsonElement + * a Gson JsonElement + * @param path + * path fragment + * @return the matching string or null in case of no match. + */ + public static String getTextValue(JsonElement jsonElement, String path, boolean returnEmptyStringIfMissing) + throws JsonPathNotFoundException { + try { + return getTextValue(jsonElement, path); + } catch (JsonPathNotFoundException e) { + return handleJsonPathNotFoundException(returnEmptyStringIfMissing); + } + } + + /** + * Search a string inside a JSON string matching the input path expression + * + * @param jsonString + * an encoded JSON string + * @param paths + * path fragments + * @return the matching string or null in case of no match. + * @throws JsonPathNotFoundException if json path is invalid + * @throws IOException if the underlying input source has IO issues during parsing + */ + public static String getTextValue(String jsonString, String paths, boolean returnEmptyStringIfMissing) throws JsonPathNotFoundException, + IOException { + try { + return getTextValue(jsonString, paths); + } catch (JsonPathNotFoundException e) { + return handleJsonPathNotFoundException(returnEmptyStringIfMissing); + } + } + + /** + * Search an array inside a JSON string matching the input path expression and convert each element using a function + * + * @param jsonString an encoded JSON string + * @param function function to parse each element + * @param paths path fragments + * @return an array of values + * @throws JsonPathNotFoundException if json path is invalid + * @throws IOException if the underlying input source has problems + * during parsing + */ + @SuppressWarnings("unchecked") + public static T[] getArrayValue(String jsonString, Function function, String... 
paths) + throws JsonPathNotFoundException, + IOException { + JsonElement jsonNode = getChildNode(jsonString, paths); + if (jsonNode != null && jsonNode.isJsonArray()) { + JsonArray array = jsonNode.getAsJsonArray(); + Object[] values = new Object[array.size()]; + for (int i = 0; i < array.size(); i++) { + values[i] = function.apply(array.get(i)); + } + return (T[]) values; + } + throw new JsonPathNotFoundException(); + } + + /** + * Search an array inside a JSON string matching the input path expression + * + * @param jsonString an encoded JSON string + * @param paths path fragments + * @throws JsonPathNotFoundException if json path is invalid + * @throws IOException if the underlying input source has problems + * during parsing + */ + @SuppressWarnings("unchecked") + public static JsonArray getArrayValue(String jsonString, String... paths) throws JsonPathNotFoundException, IOException { + JsonElement jsonNode = getChildNode(jsonString, paths); + if (jsonNode != null && jsonNode.isJsonArray()) { + return jsonNode.getAsJsonArray(); + } + throw new JsonPathNotFoundException(); + } + + /** + * Search a double number inside a JSON string matching the input path + * expression + * + * @param jsonString an encoded JSON string + * @param paths path fragments + * @return the matching double number + * @throws JsonPathNotFoundException if json path is invalid + * @throws IOException if the underlying input source has problems + * during parsing + */ + public static double getDoubleValue(String jsonString, String... paths) throws JsonPathNotFoundException, IOException { + JsonElement jsonNode = getChildNode(jsonString, paths); + if (jsonNode != null) { + return jsonNode.getAsDouble(); + } + throw new JsonPathNotFoundException(); + } + + /** + * Search a float number inside a JSON string matching the input path + * expression + * + * @param jsonString an encoded JSON string + * @param paths path fragments + * @return the matching double number + * @throws JsonPathNotFoundException if json path is invalid + * @throws IOException if the underlying input source has problems + * during parsing + */ + public static double getFloatValue(String jsonString, String... paths) throws JsonPathNotFoundException, IOException { + JsonElement jsonNode = getChildNode(jsonString, paths); + if (jsonNode != null) { + return jsonNode.getAsFloat(); + } + throw new JsonPathNotFoundException(); + } + + /** + * Search an int number inside a JSON string matching the input path expression + * + * @param jsonString an encoded JSON string + * @param paths path fragments + * @return list of double + * @throws JsonPathNotFoundException if json path is invalid + * @throws IOException if the underlying input source has problems + * during parsing + */ + public static double[] getDoubleArrayValue(String jsonString, String... paths) throws JsonPathNotFoundException, IOException { + JsonElement jsonNode = getChildNode(jsonString, paths); + if (jsonNode != null && jsonNode.isJsonArray()) { + JsonArray array = jsonNode.getAsJsonArray(); + List values = new ArrayList<>(); + for (int i = 0; i < array.size(); i++) { + values.add(array.get(i).getAsDouble()); + } + return values.stream().mapToDouble(i -> i).toArray(); + } + throw new JsonPathNotFoundException(); + } + + public static List getListValue(String jsonString, Function function, String... 
paths) + throws JsonPathNotFoundException, + IOException { + JsonElement jsonNode = getChildNode(jsonString, paths); + if (jsonNode != null && jsonNode.isJsonArray()) { + JsonArray array = jsonNode.getAsJsonArray(); + List values = new ArrayList<>(array.size()); + for (int i = 0; i < array.size(); i++) { + values.add(function.apply(array.get(i))); + } + return values; + } + throw new JsonPathNotFoundException(); + } + + public static double getDoubleValue(JsonElement jsonElement, String... paths) throws JsonPathNotFoundException, IOException { + JsonElement jsonNode = getChildNode(jsonElement, paths); + if (jsonNode != null) { + return jsonNode.getAsDouble(); + } + throw new JsonPathNotFoundException(); + } +} diff --git a/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java-e b/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java-e new file mode 100644 index 000000000..f77c135fb --- /dev/null +++ b/src/test/java/test/org/opensearch/ad/util/LabelledAnomalyGenerator.java-e @@ -0,0 +1,105 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package test.org.opensearch.ad.util; + +import static java.lang.Math.PI; + +import java.util.Random; + +import org.joda.time.Instant; + +public class LabelledAnomalyGenerator { + /** + * Generate labbelled multi-dimensional data + * @param num the number of data points + * @param period cosine periods + * @param amplitude cosine amplitude + * @param noise noise amplitude + * @param seed random seed + * @param baseDimension input dimension + * @param useSlope whether to use slope in cosine data + * @param historicalData the number of historical points relative to now + * @param delta point interval + * @param anomalyIndependent whether anomalies in each dimension is generated independently + * @return the labelled data + */ + public static MultiDimDataWithTime getMultiDimData( + int num, + int period, + double amplitude, + double noise, + long seed, + int baseDimension, + boolean useSlope, + int historicalData, + int delta, + boolean anomalyIndependent + ) { + double[][] data = new double[num][]; + long[] timestamps = new long[num]; + double[][] changes = new double[num][]; + long[] changedTimestamps = new long[num]; + Random prg = new Random(seed); + Random noiseprg = new Random(prg.nextLong()); + double[] phase = new double[baseDimension]; + double[] amp = new double[baseDimension]; + double[] slope = new double[baseDimension]; + + for (int i = 0; i < baseDimension; i++) { + phase[i] = prg.nextInt(period); + amp[i] = (1 + 0.2 * prg.nextDouble()) * amplitude; + if (useSlope) { + slope[i] = (0.25 - prg.nextDouble() * 0.5) * amplitude / period; + } + } + + long startEpochMs = Instant.now().getMillis() - historicalData * delta; + for (int i = 0; i < num; i++) { + timestamps[i] = startEpochMs; + startEpochMs += delta; + data[i] = new double[baseDimension]; + double[] newChange = new double[baseDimension]; + // decide whether we should inject anomalies at this point + // If we do this for each dimension, each dimension's anomalies + // are independent and will make it harder for RCF to detect anomalies. + // Doing it in point level will make each dimension's anomalies + // correlated. 
+ if (anomalyIndependent) { + for (int j = 0; j < baseDimension; j++) { + data[i][j] = amp[j] * Math.cos(2 * PI * (i + phase[j]) / period) + slope[j] * i + noise * noiseprg.nextDouble(); + if (noiseprg.nextDouble() < 0.01 && noiseprg.nextDouble() < 0.3) { + double factor = 5 * (1 + noiseprg.nextDouble()); + double change = noiseprg.nextDouble() < 0.5 ? factor * noise : -factor * noise; + data[i][j] += newChange[j] = change; + changedTimestamps[i] = timestamps[i]; + changes[i] = newChange; + } + } + } else { + boolean flag = (noiseprg.nextDouble() < 0.01); + for (int j = 0; j < baseDimension; j++) { + data[i][j] = amp[j] * Math.cos(2 * PI * (i + phase[j]) / period) + slope[j] * i + noise * noiseprg.nextDouble(); + // adding the condition < 0.3 so there is still some variance if all features have an anomaly or not + if (flag && noiseprg.nextDouble() < 0.3) { + double factor = 5 * (1 + noiseprg.nextDouble()); + double change = noiseprg.nextDouble() < 0.5 ? factor * noise : -factor * noise; + data[i][j] += newChange[j] = change; + changedTimestamps[i] = timestamps[i]; + changes[i] = newChange; + } + } + } + } + + return new MultiDimDataWithTime(data, changedTimestamps, changes, timestamps); + } +} diff --git a/src/test/java/test/org/opensearch/ad/util/MLUtil.java-e b/src/test/java/test/org/opensearch/ad/util/MLUtil.java-e new file mode 100644 index 000000000..babae59ef --- /dev/null +++ b/src/test/java/test/org/opensearch/ad/util/MLUtil.java-e @@ -0,0 +1,197 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package test.org.opensearch.ad.util; + +import static java.lang.Math.PI; + +import java.time.Clock; +import java.util.ArrayDeque; +import java.util.HashMap; +import java.util.Map; +import java.util.Queue; +import java.util.Random; +import java.util.stream.IntStream; + +import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ModelManager.ModelType; +import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.common.collect.Tuple; +import org.opensearch.timeseries.model.Entity; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * Cannot use TestUtil inside ML tests since it uses com.carrotsearch.randomizedtesting.RandomizedRunner + * and using it causes Exception in ML tests. + * Most of ML tests are not a subclass if ES base test case. + * + */ +public class MLUtil { + private static Random random = new Random(42); + private static int minSampleSize = AnomalyDetectorSettings.NUM_MIN_SAMPLES; + + private static String randomString(int targetStringLength) { + int leftLimit = 97; // letter 'a' + int rightLimit = 122; // letter 'z' + Random random = new Random(); + + return random + .ints(leftLimit, rightLimit + 1) + .limit(targetStringLength) + .collect(StringBuilder::new, StringBuilder::appendCodePoint, StringBuilder::append) + .toString(); + } + + public static Queue createQueueSamples(int size) { + Queue res = new ArrayDeque<>(); + IntStream.range(0, size).forEach(i -> res.offer(new double[] { random.nextDouble() })); + return res; + } + + public static ModelState randomModelState(RandomModelStateConfig config) { + boolean fullModel = config.getFullModel() != null && config.getFullModel().booleanValue() ? 
true : false; + float priority = config.getPriority() != null ? config.getPriority() : random.nextFloat(); + String detectorId = config.getId() != null ? config.getId() : randomString(15); + int sampleSize = config.getSampleSize() != null ? config.getSampleSize() : random.nextInt(minSampleSize); + Clock clock = config.getClock() != null ? config.getClock() : Clock.systemUTC(); + + Entity entity = null; + if (config.hasEntityAttributes()) { + Map attributes = new HashMap<>(); + attributes.put("a", "a1"); + attributes.put("b", "b1"); + entity = Entity.createEntityByReordering(attributes); + } else { + entity = Entity.createSingleAttributeEntity("", ""); + } + EntityModel model = null; + if (fullModel) { + model = createNonEmptyModel(detectorId, sampleSize, entity); + } else { + model = createEmptyModel(entity, sampleSize); + } + + return new ModelState<>(model, detectorId, detectorId, ModelType.ENTITY.getName(), clock, priority); + } + + public static EntityModel createEmptyModel(Entity entity, int sampleSize) { + Queue samples = createQueueSamples(sampleSize); + return new EntityModel(entity, samples, null); + } + + public static EntityModel createEmptyModel(Entity entity) { + return createEmptyModel(entity, random.nextInt(minSampleSize)); + } + + public static EntityModel createNonEmptyModel(String detectorId, int sampleSize, Entity entity) { + Queue samples = createQueueSamples(sampleSize); + int numDataPoints = random.nextInt(1000) + AnomalyDetectorSettings.NUM_MIN_SAMPLES; + ThresholdedRandomCutForest trcf = new ThresholdedRandomCutForest( + ThresholdedRandomCutForest + .builder() + .dimensions(1) + .sampleSize(AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE) + .numberOfTrees(AnomalyDetectorSettings.NUM_TREES) + .timeDecay(AnomalyDetectorSettings.TIME_DECAY) + .outputAfter(AnomalyDetectorSettings.NUM_MIN_SAMPLES) + .initialAcceptFraction(0.125d) + .parallelExecutionEnabled(false) + .internalShinglingEnabled(true) + .anomalyRate(1 - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE) + ); + for (int i = 0; i < numDataPoints; i++) { + trcf.process(new double[] { random.nextDouble() }, i); + } + EntityModel entityModel = new EntityModel(entity, samples, trcf); + return entityModel; + } + + public static EntityModel createNonEmptyModel(String detectorId) { + return createNonEmptyModel(detectorId, random.nextInt(minSampleSize), Entity.createSingleAttributeEntity("", "")); + } + + /** + * Generate shingled data + * @param size the number of data points + * @param dimensions the dimensions of a point + * @param seed random seed + * @return the shingled data + */ + public static double[][] generateShingledData(int size, int dimensions, long seed) { + double[][] answer = new double[size][]; + int entryIndex = 0; + boolean filledShingleAtleastOnce = false; + double[] history = new double[dimensions]; + int count = 0; + double[] data = getDataD(size + dimensions - 1, 100, 5, seed); + for (int j = 0; j < size + dimensions - 1; ++j) { + history[entryIndex] = data[j]; + entryIndex = (entryIndex + 1) % dimensions; + if (entryIndex == 0) { + filledShingleAtleastOnce = true; + } + if (filledShingleAtleastOnce) { + answer[count++] = getShinglePoint(history, entryIndex, dimensions); + } + } + return answer; + } + + private static double[] getShinglePoint(double[] recentPointsSeen, int indexOfOldestPoint, int shingleLength) { + double[] shingledPoint = new double[shingleLength]; + int i = 0; + for (int j = 0; j < shingleLength; ++j) { + double point = recentPointsSeen[(j + indexOfOldestPoint) % shingleLength]; + 
shingledPoint[i++] = point; + + } + return shingledPoint; + } + + static double[] getDataD(int num, double amplitude, double noise, long seed) { + + double[] data = new double[num]; + Random noiseprg = new Random(seed); + for (int i = 0; i < num; i++) { + data[i] = amplitude * Math.cos(2 * PI * (i + 50) / 1000) + noise * noiseprg.nextDouble(); + } + + return data; + } + + /** + * Prepare models and return training samples + * @param inputDimension Input dimension + * @param rcfConfig RCF config + * @return models and return training samples + */ + public static Tuple, ThresholdedRandomCutForest> prepareModel( + int inputDimension, + ThresholdedRandomCutForest.Builder rcfConfig + ) { + Queue samples = new ArrayDeque<>(); + + Random r = new Random(); + ThresholdedRandomCutForest rcf = new ThresholdedRandomCutForest(rcfConfig); + + int trainDataNum = 1000; + + for (int i = 0; i < trainDataNum; i++) { + double[] point = r.ints(inputDimension, 0, 50).asDoubleStream().toArray(); + samples.add(point); + rcf.process(point, 0); + } + + return Tuple.tuple(samples, rcf); + } +} diff --git a/src/test/java/test/org/opensearch/ad/util/MultiDimDataWithTime.java-e b/src/test/java/test/org/opensearch/ad/util/MultiDimDataWithTime.java-e new file mode 100644 index 000000000..6670d53a7 --- /dev/null +++ b/src/test/java/test/org/opensearch/ad/util/MultiDimDataWithTime.java-e @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package test.org.opensearch.ad.util; + +public class MultiDimDataWithTime { + public double[][] data; + public long[] changeTimeStampsMs; + public double[][] changes; + public long[] timestampsMs; + + public MultiDimDataWithTime(double[][] data, long[] changeTimestamps, double[][] changes, long[] timestampsMs) { + this.data = data; + this.changeTimeStampsMs = changeTimestamps; + this.changes = changes; + this.timestampsMs = timestampsMs; + } +} diff --git a/src/test/java/test/org/opensearch/ad/util/RandomModelStateConfig.java-e b/src/test/java/test/org/opensearch/ad/util/RandomModelStateConfig.java-e new file mode 100644 index 000000000..25a2da1bd --- /dev/null +++ b/src/test/java/test/org/opensearch/ad/util/RandomModelStateConfig.java-e @@ -0,0 +1,100 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package test.org.opensearch.ad.util; + +import java.time.Clock; + +public class RandomModelStateConfig { + private final Boolean fullModel; + private final Float priority; + private final String detectorId; + private final Integer sampleSize; + private final Clock clock; + private final Boolean entityAttributes; + + private RandomModelStateConfig(Builder builder) { + this.fullModel = builder.fullModel; + this.priority = builder.priority; + this.detectorId = builder.detectorId; + this.sampleSize = builder.sampleSize; + this.clock = builder.clock; + this.entityAttributes = builder.entityAttributes; + } + + public Boolean getFullModel() { + return fullModel; + } + + public Float getPriority() { + return priority; + } + + public String getId() { + return detectorId; + } + + public Integer getSampleSize() { + return sampleSize; + } + + public Clock getClock() { + return clock; + } + + public Boolean hasEntityAttributes() { + return entityAttributes; + } + + public static class Builder { + private Boolean fullModel = null; + private Float priority = null; + private String detectorId = null; + private Integer sampleSize = null; + private Clock clock = null; + private Boolean entityAttributes = false; + + public Builder fullModel(boolean fullModel) { + this.fullModel = fullModel; + return this; + } + + public Builder priority(float priority) { + this.priority = priority; + return this; + } + + public Builder detectorId(String detectorId) { + this.detectorId = detectorId; + return this; + } + + public Builder sampleSize(int sampleSize) { + this.sampleSize = sampleSize; + return this; + } + + public Builder clock(Clock clock) { + this.clock = clock; + return this; + } + + public Builder entityAttributes(Boolean entityAttributes) { + this.entityAttributes = entityAttributes; + return this; + } + + public RandomModelStateConfig build() { + RandomModelStateConfig config = new RandomModelStateConfig(this); + return config; + } + } +}
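Usage sketch (editor's addition, not part of the patch): the test fixtures added above are designed to be composed, with RandomModelStateConfig.Builder describing the desired state and MLUtil.randomModelState materializing it. The following minimal example assumes the classes compile exactly as added in this patch; the <EntityModel> type argument on ModelState and the detector id string are assumptions for illustration, since generic parameters are not visible in the patch text as rendered here.

import org.opensearch.ad.ml.EntityModel;
import org.opensearch.ad.ml.ModelState;

import test.org.opensearch.ad.util.MLUtil;
import test.org.opensearch.ad.util.RandomModelStateConfig;

public class ModelStateFixtureSketch {
    public static void main(String[] args) {
        // Describe the desired random model state: a fully initialized entity model
        // with a pinned priority, an illustrative detector id, and entity attributes.
        RandomModelStateConfig config = new RandomModelStateConfig.Builder()
            .fullModel(true)
            .priority(0.5f)
            .detectorId("sketch-detector") // illustrative id, not taken from the patch
            .sampleSize(32)
            .entityAttributes(true)
            .build();

        // MLUtil.randomModelState builds the EntityModel (backed by a
        // ThresholdedRandomCutForest when fullModel is true) and wraps it in a ModelState.
        ModelState<EntityModel> state = MLUtil.randomModelState(config);

        // A test would now hand this state to whatever component it exercises.
        System.out.println("created state: " + state);
    }
}

Because every Builder field is nullable, MLUtil falls back to randomized defaults (random priority, random sample size below NUM_MIN_SAMPLES, system UTC clock) whenever a test does not pin a value explicitly, which keeps individual tests short while still covering varied states.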