Skip to content

Commit

Permalink
add mitre attack based auto-correlations support in correlation engine (
Browse files Browse the repository at this point in the history
#532)

Signed-off-by: Subhobrata Dey <sbcd90@gmail.com>
  • Loading branch information
sbcd90 authored Sep 6, 2023
1 parent 1cb5ddc commit 32d5aa1
Show file tree
Hide file tree
Showing 5 changed files with 9,922 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.apache.lucene.search.join.ScoreMode;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.cluster.routing.Preference;
import org.opensearch.commons.alerting.model.DocLevelQuery;
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.search.MultiSearchRequest;
import org.opensearch.action.search.MultiSearchResponse;
Expand All @@ -23,26 +24,32 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.NestedQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.securityanalytics.config.monitors.DetectorMonitorConfig;
import org.opensearch.securityanalytics.logtype.LogTypeService;
import org.opensearch.securityanalytics.model.CorrelationQuery;
import org.opensearch.securityanalytics.model.CorrelationRule;
import org.opensearch.securityanalytics.model.Detector;
import org.opensearch.securityanalytics.transport.TransportCorrelateFindingAction;
import org.opensearch.securityanalytics.util.AutoCorrelationsRepo;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;


Expand All @@ -58,18 +65,147 @@ public class JoinEngine {

private final TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction;

private final LogTypeService logTypeService;

private static final Logger log = LogManager.getLogger(JoinEngine.class);

public JoinEngine(Client client, PublishFindingsRequest request, NamedXContentRegistry xContentRegistry,
long corrTimeWindow, TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction) {
long corrTimeWindow, TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction,
LogTypeService logTypeService) {
this.client = client;
this.request = request;
this.xContentRegistry = xContentRegistry;
this.corrTimeWindow = corrTimeWindow;
this.correlateFindingAction = correlateFindingAction;
this.logTypeService = logTypeService;
}

public void onSearchDetectorResponse(Detector detector, Finding finding) {
try {
generateAutoCorrelations(detector, finding);
} catch (IOException ex) {
correlateFindingAction.onFailures(ex);
}
}

@SuppressWarnings("unchecked")
private void generateAutoCorrelations(Detector detector, Finding finding) throws IOException {
Map<String, Set<String>> autoCorrelations = AutoCorrelationsRepo.autoCorrelationsAsMap();
long findingTimestamp = finding.getTimestamp().toEpochMilli();

Set<String> tags = new HashSet<>();
for (DocLevelQuery query : finding.getDocLevelQueries()) {
tags.addAll(query.getTags().stream().filter(tag -> tag.startsWith("attack.")).collect(Collectors.toList()));
}
Set<String> validIntrusionSets = AutoCorrelationsRepo.validIntrusionSets(autoCorrelations, tags);

MatchQueryBuilder queryBuilder = QueryBuilders.matchQuery("source", "Sigma");

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(queryBuilder);
searchSourceBuilder.size(100);

SearchRequest request = new SearchRequest();
request.source(searchSourceBuilder);
logTypeService.searchLogTypes(request, new ActionListener<>() {
@Override
public void onResponse(SearchResponse response) {
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
SearchHit[] logTypes = response.getHits().getHits();
List<String> logTypeNames = new ArrayList<>();
for (SearchHit logType: logTypes) {
String logTypeName = logType.getSourceAsMap().get("name").toString();
logTypeNames.add(logTypeName);

RangeQueryBuilder queryBuilder = QueryBuilders.rangeQuery("timestamp")
.gte(findingTimestamp - corrTimeWindow)
.lte(findingTimestamp + corrTimeWindow);

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(queryBuilder);
searchSourceBuilder.size(10000);
searchSourceBuilder.fetchField("queries");
SearchRequest searchRequest = new SearchRequest();
searchRequest.indices(DetectorMonitorConfig.getAllFindingsIndicesPattern(logTypeName));
searchRequest.source(searchSourceBuilder);
searchRequest.preference(Preference.PRIMARY_FIRST.type());
mSearchRequest.add(searchRequest);
}

if (!mSearchRequest.requests().isEmpty()) {
client.multiSearch(mSearchRequest, new ActionListener<>() {
@Override
public void onResponse(MultiSearchResponse items) {
MultiSearchResponse.Item[] responses = items.getResponses();

Map<String, List<String>> autoCorrelationsMap = new HashMap<>();
int idx = 0;
for (MultiSearchResponse.Item response : responses) {
if (response.isFailure()) {
log.info(response.getFailureMessage());
continue;
}
String logTypeName = logTypeNames.get(idx);

SearchHit[] findings = response.getResponse().getHits().getHits();

for (SearchHit foundFinding : findings) {
if (!foundFinding.getId().equals(finding.getId())) {
Set<String> findingTags = new HashSet<>();
List<Map<String, Object>> queries = (List<Map<String, Object>>) foundFinding.getSourceAsMap().get("queries");
for (Map<String, Object> query : queries) {
List<String> queryTags = (List<String>) query.get("tags");
findingTags.addAll(queryTags.stream().filter(queryTag -> queryTag.startsWith("attack.")).collect(Collectors.toList()));
}

boolean canCorrelate = false;
for (String tag: tags) {
if (findingTags.contains(tag)) {
canCorrelate = true;
break;
}
}

Set<String> foundIntrusionSets = AutoCorrelationsRepo.validIntrusionSets(autoCorrelations, findingTags);
for (String validIntrusionSet: validIntrusionSets) {
if (foundIntrusionSets.contains(validIntrusionSet)) {
canCorrelate = true;
break;
}
}

if (canCorrelate) {
if (autoCorrelationsMap.containsKey(logTypeName)) {
autoCorrelationsMap.get(logTypeName).add(foundFinding.getId());
} else {
List<String> autoCorrelatedFindings = new ArrayList<>();
autoCorrelatedFindings.add(foundFinding.getId());
autoCorrelationsMap.put(logTypeName, autoCorrelatedFindings);
}
}
}
}
++idx;
}
onAutoCorrelations(detector, finding, autoCorrelationsMap);
}

@Override
public void onFailure(Exception e) {
correlateFindingAction.onFailures(e);
}
});
}
}

@Override
public void onFailure(Exception e) {
correlateFindingAction.onFailures(e);
}
});
}

private void onAutoCorrelations(Detector detector, Finding finding, Map<String, List<String>> autoCorrelations) {
String detectorType = detector.getDetectorType().toLowerCase(Locale.ROOT);
List<String> indices = detector.getInputs().get(0).getIndices();
List<String> relatedDocIds = finding.getCorrelatedDocIds();
Expand Down Expand Up @@ -113,20 +249,20 @@ public void onResponse(SearchResponse response) {
}
}

getValidDocuments(detectorType, indices, correlationRules, relatedDocIds);
getValidDocuments(detectorType, indices, correlationRules, relatedDocIds, autoCorrelations);
}

@Override
public void onFailure(Exception e) {
correlateFindingAction.onFailures(e);
getValidDocuments(detectorType, indices, List.of(), List.of(), autoCorrelations);
}
});
}

/**
* this method checks if the finding to be correlated has valid related docs(or not) which match join criteria.
*/
private void getValidDocuments(String detectorType, List<String> indices, List<CorrelationRule> correlationRules, List<String> relatedDocIds) {
private void getValidDocuments(String detectorType, List<String> indices, List<CorrelationRule> correlationRules, List<String> relatedDocIds, Map<String, List<String>> autoCorrelations) {
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<CorrelationRule> validCorrelationRules = new ArrayList<>();

Expand Down Expand Up @@ -189,7 +325,9 @@ public void onResponse(MultiSearchResponse items) {
}
}
searchFindingsByTimestamp(detectorType, categoryToQueriesMap,
filteredCorrelationRules.stream().map(CorrelationRule::getId).collect(Collectors.toList()));
filteredCorrelationRules.stream().map(CorrelationRule::getId).collect(Collectors.toList()),
autoCorrelations
);
}

@Override
Expand All @@ -198,15 +336,19 @@ public void onFailure(Exception e) {
}
});
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), List.of());
if (!autoCorrelations.isEmpty()) {
correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of());
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), List.of());
}
}
}

/**
* this method searches for parent findings given the log category & correlation time window & collects all related docs
* for them.
*/
private void searchFindingsByTimestamp(String detectorType, Map<String, List<CorrelationQuery>> categoryToQueriesMap, List<String> correlationRules) {
private void searchFindingsByTimestamp(String detectorType, Map<String, List<CorrelationQuery>> categoryToQueriesMap, List<String> correlationRules, Map<String, List<String>> autoCorrelations) {
long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli();
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<Pair<String, List<CorrelationQuery>>> categoryToQueriesPairs = new ArrayList<>();
Expand Down Expand Up @@ -260,7 +402,7 @@ public void onResponse(MultiSearchResponse items) {
relatedDocIds));
++idx;
}
searchDocsWithFilterKeys(detectorType, relatedDocsMap, correlationRules);
searchDocsWithFilterKeys(detectorType, relatedDocsMap, correlationRules, autoCorrelations);
}

@Override
Expand All @@ -269,14 +411,18 @@ public void onFailure(Exception e) {
}
});
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
if (!autoCorrelations.isEmpty()) {
correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of());
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
}
}
}

/**
* Given the related docs from parent findings, this method filters only those related docs which match parent join criteria.
*/
private void searchDocsWithFilterKeys(String detectorType, Map<String, DocSearchCriteria> relatedDocsMap, List<String> correlationRules) {
private void searchDocsWithFilterKeys(String detectorType, Map<String, DocSearchCriteria> relatedDocsMap, List<String> correlationRules, Map<String, List<String>> autoCorrelations) {
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<String> categories = new ArrayList<>();

Expand Down Expand Up @@ -324,7 +470,7 @@ public void onResponse(MultiSearchResponse items) {
filteredRelatedDocIds.put(categories.get(idx), docIds);
++idx;
}
getCorrelatedFindings(detectorType, filteredRelatedDocIds, correlationRules);
getCorrelatedFindings(detectorType, filteredRelatedDocIds, correlationRules, autoCorrelations);
}

@Override
Expand All @@ -333,15 +479,19 @@ public void onFailure(Exception e) {
}
});
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
if (!autoCorrelations.isEmpty()) {
correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of());
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
}
}
}

/**
* Given the filtered related docs of the parent findings, this method gets the actual filtered parent findings for
* the finding to be correlated.
*/
private void getCorrelatedFindings(String detectorType, Map<String, List<String>> filteredRelatedDocIds, List<String> correlationRules) {
private void getCorrelatedFindings(String detectorType, Map<String, List<String>> filteredRelatedDocIds, List<String> correlationRules, Map<String, List<String>> autoCorrelations) {
long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli();
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<String> categories = new ArrayList<>();
Expand Down Expand Up @@ -397,6 +547,16 @@ public void onResponse(MultiSearchResponse items) {
}
++idx;
}

for (Map.Entry<String, List<String>> autoCorrelation: autoCorrelations.entrySet()) {
if (correlatedFindings.containsKey(autoCorrelation.getKey())) {
Set<String> alreadyCorrelatedFindings = new HashSet<>(correlatedFindings.get(autoCorrelation.getKey()));
alreadyCorrelatedFindings.addAll(autoCorrelation.getValue());
correlatedFindings.put(autoCorrelation.getKey(), new ArrayList<>(alreadyCorrelatedFindings));
} else {
correlatedFindings.put(autoCorrelation.getKey(), autoCorrelation.getValue());
}
}
correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRules);
}

Expand All @@ -406,7 +566,11 @@ public void onFailure(Exception e) {
}
});
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
if (!autoCorrelations.isEmpty()) {
correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of());
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ public class TransportCorrelateFindingAction extends HandledTransportAction<Acti

private final CorrelationIndices correlationIndices;

private final LogTypeService logTypeService;

private final ClusterService clusterService;

private final Settings settings;
Expand All @@ -100,6 +102,7 @@ public TransportCorrelateFindingAction(TransportService transportService,
NamedXContentRegistry xContentRegistry,
DetectorIndices detectorIndices,
CorrelationIndices correlationIndices,
LogTypeService logTypeService,
ClusterService clusterService,
Settings settings,
ActionFilters actionFilters) {
Expand All @@ -108,6 +111,7 @@ public TransportCorrelateFindingAction(TransportService transportService,
this.xContentRegistry = xContentRegistry;
this.detectorIndices = detectorIndices;
this.correlationIndices = correlationIndices;
this.logTypeService = logTypeService;
this.clusterService = clusterService;
this.settings = settings;
this.threadPool = this.detectorIndices.getThreadPool();
Expand Down Expand Up @@ -186,7 +190,7 @@ public class AsyncCorrelateFindingAction {

this.response =new AtomicReference<>();

this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, this);
this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, this, logTypeService);
this.vectorEmbeddingsEngine = new VectorEmbeddingsEngine(client, indexTimeout, corrTimeWindow, this);
}

Expand Down
Loading

0 comments on commit 32d5aa1

Please sign in to comment.