Skip to content

Commit

Permalink
Add rule validation in AnomalyDetector constructor
Browse files Browse the repository at this point in the history
This commit introduces rule validation within the AnomalyDetector constructor. Any validation errors are now propagated and displayed on the frontend to ensure immediate feedback.

Testing:
* Verified that validation errors are properly propagated and shown on the frontend.
* Added UTs to cover the new validation logic.

Signed-off-by: Kaituo Li <kaituo@amazon.com>
  • Loading branch information
kaituo committed Oct 18, 2024
1 parent 2ab6dc7 commit 4bc262c
Show file tree
Hide file tree
Showing 12 changed files with 664 additions and 81 deletions.
3 changes: 0 additions & 3 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,6 @@ List<String> jacocoExclusions = [

// TODO: add test coverage (kaituo)
'org.opensearch.forecast.*',
'org.opensearch.timeseries.transport.SuggestConfigParamResponse',
'org.opensearch.timeseries.transport.SuggestConfigParamRequest',
'org.opensearch.timeseries.ml.MemoryAwareConcurrentHashmap',
'org.opensearch.timeseries.transport.ResultBulkTransportAction',
'org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler',
'org.opensearch.timeseries.transport.handler.ResultIndexingHandler',
Expand Down
26 changes: 1 addition & 25 deletions src/main/java/org/opensearch/ad/ml/ADModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
Expand All @@ -42,7 +41,6 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.timeseries.AnalysisModelSize;
import org.opensearch.timeseries.MemoryTracker;
import org.opensearch.timeseries.common.exception.ResourceNotFoundException;
import org.opensearch.timeseries.common.exception.TimeSeriesException;
Expand All @@ -52,7 +50,6 @@
import org.opensearch.timeseries.ml.ModelColdStart;
import org.opensearch.timeseries.ml.ModelManager;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.SingleStreamModelIdMapper;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.settings.TimeSeriesSettings;
import org.opensearch.timeseries.util.DateUtils;
Expand All @@ -69,9 +66,7 @@
* A facade managing ML operations and models.
*/
public class ADModelManager extends
ModelManager<ThresholdedRandomCutForest, AnomalyResult, ThresholdingResult, ADIndex, ADIndexManagement, ADCheckpointDao, ADCheckpointWriteWorker, ADColdStart>
implements
AnalysisModelSize {
ModelManager<ThresholdedRandomCutForest, AnomalyResult, ThresholdingResult, ADIndex, ADIndexManagement, ADCheckpointDao, ADCheckpointWriteWorker, ADColdStart> {
protected static final String ENTITY_SAMPLE = "sp";
protected static final String ENTITY_RCF = "rcf";
protected static final String ENTITY_THRESHOLD = "th";
Expand Down Expand Up @@ -594,25 +589,6 @@ public List<ThresholdingResult> getPreviewResults(Features features, AnomalyDete
}).collect(Collectors.toList());
}

/**
* Get all RCF partition's size corresponding to a detector. Thresholding models' size is a constant since they are small in size (KB).
* @param detectorId detector id
* @return a map of model id to its memory size
*/
@Override
public Map<String, Long> getModelSize(String detectorId) {
Map<String, Long> res = new HashMap<>();
res.putAll(forests.getModelSize(detectorId));
thresholds
.entrySet()
.stream()
.filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(detectorId))
.forEach(entry -> {
res.put(entry.getKey(), (long) memoryTracker.getThresholdModelBytes());
});
return res;
}

/**
* Get a RCF model's total updates.
* @param modelId the RCF model's id
Expand Down
116 changes: 116 additions & 0 deletions src/main/java/org/opensearch/ad/model/AnomalyDetector.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -229,6 +230,8 @@ public AnomalyDetector(
issueType = ValidationIssueType.CATEGORY;
}

validateRules(features, rules);

checkAndThrowValidationErrors(ValidationAspect.DETECTOR);

this.detectorType = isHC(categoryFields) ? MULTI_ENTITY.name() : SINGLE_ENTITY.name();
Expand Down Expand Up @@ -720,4 +723,117 @@ private static Boolean onlyParseBooleanValue(XContentParser parser) throws IOExc
}
return null;
}

/**
* Validates each condition in the list of rules against the list of features.
* Checks that:
* - The feature name exists in the list of features.
* - The related feature is enabled.
* - The value is not NaN and is positive.
*
* @param features The list of available features. Must not be null.
* @param rules The list of rules containing conditions to validate. Can be null.
*/
private void validateRules(List<Feature> features, List<Rule> rules) {
// Null check for rules
if (rules == null || rules.isEmpty()) {
return; // No suppression rules to validate; consider as valid
}

// Null check for features
if (features == null) {
// Cannot proceed with validation if features are null but rules are not null
this.errorMessage = "Suppression Rule Error: Features are not defined while suppression rules are provided.";
this.issueType = ValidationIssueType.RULE;
return;
}

// Create a map of feature names to their enabled status for quick lookup
Map<String, Boolean> featureEnabledMap = new HashMap<>();
for (Feature feature : features) {
if (feature != null && feature.getName() != null) {
featureEnabledMap.put(feature.getName(), feature.getEnabled());
}
}

// Iterate over each rule
for (Rule rule : rules) {
if (rule == null || rule.getConditions() == null) {
// Invalid rule or conditions list is null
this.errorMessage = "Suppression Rule Error: A suppression rule or its conditions are not properly defined.";
this.issueType = ValidationIssueType.RULE;
return;
}

// Iterate over each condition in the rule
for (Condition condition : rule.getConditions()) {
if (condition == null) {
// Invalid condition
this.errorMessage = "Suppression Rule Error: A condition within a suppression rule is not properly defined.";
this.issueType = ValidationIssueType.RULE;
return;
}

String featureName = condition.getFeatureName();

// Check if the feature name is null
if (featureName == null) {
// Feature name is required
this.errorMessage = "Suppression Rule Error: A condition is missing the feature name.";
this.issueType = ValidationIssueType.RULE;
return;
}

// Check if the feature exists
if (!featureEnabledMap.containsKey(featureName)) {
// Feature does not exist
this.errorMessage = "Suppression Rule Error: Feature \""
+ featureName
+ "\" specified in a suppression rule does not exist.";
this.issueType = ValidationIssueType.RULE;
return;
}

// Check if the feature is enabled
if (!featureEnabledMap.get(featureName)) {
// Feature is not enabled
this.errorMessage = "Suppression Rule Error: Feature \""
+ featureName
+ "\" specified in a suppression rule is not enabled.";
this.issueType = ValidationIssueType.RULE;
return;
}

// other threshold types may not have value operand
ThresholdType thresholdType = condition.getThresholdType();
if (thresholdType == ThresholdType.ACTUAL_OVER_EXPECTED_MARGIN
|| thresholdType == ThresholdType.EXPECTED_OVER_ACTUAL_MARGIN
|| thresholdType == ThresholdType.ACTUAL_OVER_EXPECTED_RATIO
|| thresholdType == ThresholdType.EXPECTED_OVER_ACTUAL_RATIO) {
// Check if the value is not NaN
double value = condition.getValue();
if (Double.isNaN(value)) {
// Value is NaN
this.errorMessage = "Suppression Rule Error: The threshold value for feature \""
+ featureName
+ "\" is not a valid number.";
this.issueType = ValidationIssueType.RULE;
return;
}

// Check if the value is positive
if (value <= 0) {
// Value is not positive
this.errorMessage = "Suppression Rule Error: The threshold value for feature \""
+ featureName
+ "\" must be a positive number.";
this.issueType = ValidationIssueType.RULE;
return;
}
}
}
}

// All checks passed
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@

package org.opensearch.timeseries.ml;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

import org.opensearch.timeseries.MemoryTracker;
Expand Down Expand Up @@ -55,48 +52,4 @@ public ModelState<RCFModelType> put(String key, ModelState<RCFModelType> value)
}
return previousAssociatedState;
}

/**
* Gets all of a config's model sizes hosted on a node
*
* @param configId config Id
* @return a map of model id to its memory size
*/
public Map<String, Long> getModelSize(String configId) {
Map<String, Long> res = new HashMap<>();
super.entrySet()
.stream()
.filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(configId))
.forEach(entry -> {
Optional<RCFModelType> modelOptional = entry.getValue().getModel();
if (modelOptional.isPresent()) {
res.put(entry.getKey(), memoryTracker.estimateTRCFModelSize(modelOptional.get()));
}
});
return res;
}

/**
* Checks if a model exists for the given config.
* @param configId Config Id
* @return `true` if the model exists, `false` otherwise.
*/
public boolean doesModelExist(String configId) {
return super.entrySet()
.stream()
.filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(configId))
.anyMatch(n -> true);
}

public boolean hostIfPossible(String modelId, ModelState<RCFModelType> toUpdate) {
return Optional
.ofNullable(toUpdate)
.filter(state -> state.getModel().isPresent())
.filter(state -> memoryTracker.isHostingAllowed(modelId, state.getModel().get()))
.map(state -> {
super.put(modelId, toUpdate);
return true;
})
.orElse(false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ public enum ValidationIssueType implements Name {
SUBAGGREGATION(SearchTopForecastResultRequest.SUBAGGREGATIONS_FIELD),
RECENCY_EMPHASIS(Config.RECENCY_EMPHASIS_FIELD),
DESCRIPTION(Config.DESCRIPTION_FIELD),
HISTORY(Config.HISTORY_INTERVAL_FIELD);
HISTORY(Config.HISTORY_INTERVAL_FIELD),
RULE(AnomalyDetector.RULES_FIELD);

private String name;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ public class SuggestConfigParamRequest extends ActionRequest {
public SuggestConfigParamRequest(StreamInput in) throws IOException {
super(in);
context = in.readEnum(AnalysisType.class);
if (context.isAD()) {
if (getContext().isAD()) {
config = new AnomalyDetector(in);
} else if (context.isForecast()) {
} else if (getContext().isForecast()) {
config = new Forecaster(in);
} else {
throw new UnsupportedOperationException("This method is not supported");
Expand All @@ -55,7 +55,7 @@ public SuggestConfigParamRequest(AnalysisType context, Config config, String par
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeEnum(context);
out.writeEnum(getContext());
config.writeTo(out);
out.writeString(param);
out.writeTimeValue(requestTimeout);
Expand All @@ -77,4 +77,8 @@ public String getParam() {
public TimeValue getRequestTimeout() {
return requestTimeout;
}

public AnalysisType getContext() {
return context;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public void testRule() throws Exception {
minPrecision.put("Scottsdale", 0.5);
Map<String, Double> minRecall = new HashMap<>();
minRecall.put("Phoenix", 0.9);
minRecall.put("Scottsdale", 0.6);
minRecall.put("Scottsdale", 0.3);
verifyRule("rule", 10, minPrecision.size(), 1500, minPrecision, minRecall, 20);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public void testRule() throws Exception {
minPrecision.put("Scottsdale", 0.5);
Map<String, Double> minRecall = new HashMap<>();
minRecall.put("Phoenix", 0.9);
minRecall.put("Scottsdale", 0.6);
minRecall.put("Scottsdale", 0.3);
verifyRule("rule", 10, minPrecision.size(), 1500, minPrecision, minRecall, 20);
}
}
Expand Down
Loading

0 comments on commit 4bc262c

Please sign in to comment.