diff --git a/build.gradle b/build.gradle index 2e16c6b70..2a958f0b6 100644 --- a/build.gradle +++ b/build.gradle @@ -158,6 +158,8 @@ dependencies { api "org.opensearch:common-utils:${common_utils_version}@jar" api "org.opensearch.client:opensearch-rest-client:${opensearch_version}" implementation "org.jetbrains.kotlin:kotlin-stdlib:${kotlin_version}" + compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}" + implementation "org.apache.commons:commons-csv:1.10.0" // Needed for integ tests zipArchive group: 'org.opensearch.plugin', name:'alerting', version: "${opensearch_build}" diff --git a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java index 2c60321df..4da111975 100644 --- a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java +++ b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java @@ -4,20 +4,13 @@ */ package org.opensearch.securityanalytics; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Optional; +import java.util.*; import java.util.function.Supplier; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.cluster.routing.Preference; import org.opensearch.core.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.core.action.ActionResponse; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNode; @@ -38,18 +31,12 @@ import org.opensearch.index.codec.CodecServiceFactory; import org.opensearch.index.engine.EngineFactory; import org.opensearch.index.mapper.Mapper; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.plugins.ActionPlugin; -import org.opensearch.plugins.ClusterPlugin; -import org.opensearch.plugins.EnginePlugin; -import org.opensearch.plugins.MapperPlugin; -import org.opensearch.plugins.Plugin; -import org.opensearch.plugins.SearchPlugin; +import org.opensearch.indices.SystemIndexDescriptor; +import org.opensearch.plugins.*; import org.opensearch.repositories.RepositoriesService; import org.opensearch.rest.RestController; import org.opensearch.rest.RestHandler; import org.opensearch.script.ScriptService; -import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.securityanalytics.action.*; import org.opensearch.securityanalytics.correlation.index.codec.CorrelationCodecService; import org.opensearch.securityanalytics.correlation.index.mapper.CorrelationVectorFieldMapper; @@ -60,7 +47,15 @@ import org.opensearch.securityanalytics.mapper.IndexTemplateManager; import org.opensearch.securityanalytics.mapper.MapperService; import org.opensearch.securityanalytics.model.CustomLogType; +import org.opensearch.securityanalytics.model.ThreatIntelFeedData; import org.opensearch.securityanalytics.resthandler.*; +import org.opensearch.securityanalytics.threatIntel.DetectorThreatIntelService; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelFeedDataService; +import org.opensearch.securityanalytics.threatIntel.common.ThreatIntelExecutor; +import org.opensearch.securityanalytics.threatIntel.common.ThreatIntelLockService; +import org.opensearch.securityanalytics.threatIntel.dao.DatasourceDao; +import 
org.opensearch.securityanalytics.threatIntel.jobscheduler.DatasourceRunner;
+import org.opensearch.securityanalytics.threatIntel.jobscheduler.DatasourceUpdateService;
 import org.opensearch.securityanalytics.transport.*;
 import org.opensearch.securityanalytics.model.Rule;
 import org.opensearch.securityanalytics.model.Detector;
@@ -72,10 +67,13 @@
 import org.opensearch.securityanalytics.util.DetectorIndices;
 import org.opensearch.securityanalytics.util.RuleIndices;
 import org.opensearch.securityanalytics.util.RuleTopicIndices;
+import org.opensearch.threadpool.ExecutorBuilder;
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.watcher.ResourceWatcherService;
-public class SecurityAnalyticsPlugin extends Plugin implements ActionPlugin, MapperPlugin, SearchPlugin, EnginePlugin, ClusterPlugin {
+import static org.opensearch.securityanalytics.threatIntel.jobscheduler.Datasource.THREAT_INTEL_DATA_INDEX_NAME_PREFIX;
+
+public class SecurityAnalyticsPlugin extends Plugin implements ActionPlugin, MapperPlugin, SearchPlugin, EnginePlugin, ClusterPlugin, SystemIndexPlugin {

     private static final Logger log = LogManager.getLogger(SecurityAnalyticsPlugin.class);
@@ -116,6 +114,22 @@ public class SecurityAnalyticsPlugin extends Plugin implements ActionPlugin, Map

     private Client client;

+    @Override
+    public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings settings) {
+        return List.of(new SystemIndexDescriptor(THREAT_INTEL_DATA_INDEX_NAME_PREFIX, "System index used for threat intel data"));
+    }
+
+    @Override
+    public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
+        List<ExecutorBuilder<?>> executorBuilders = new ArrayList<>();
+        executorBuilders.add(ThreatIntelExecutor.executorBuilder(settings));
+        return executorBuilders;
+    }
+
     @Override
     public Collection<Object> createComponents(Client client,
                                                ClusterService clusterService,
@@ -128,6 +142,7 @@ public Collection<Object> createComponents(Client client,
                                                NamedWriteableRegistry namedWriteableRegistry,
                                                IndexNameExpressionResolver indexNameExpressionResolver,
                                                Supplier<RepositoriesService> repositoriesServiceSupplier) {
+
         builtinLogTypeLoader = new BuiltinLogTypeLoader();
         logTypeService = new LogTypeService(client, clusterService, xContentRegistry, builtinLogTypeLoader);
         detectorIndices = new DetectorIndices(client.admin(), clusterService, threadPool);
@@ -138,11 +153,22 @@ public Collection<Object> createComponents(Client client,
         mapperService = new MapperService(client, clusterService, indexNameExpressionResolver, indexTemplateManager, logTypeService);
         ruleIndices = new RuleIndices(logTypeService, client, clusterService, threadPool);
         correlationRuleIndices = new CorrelationRuleIndices(client, clusterService);
+        ThreatIntelFeedDataService threatIntelFeedDataService = new ThreatIntelFeedDataService(clusterService.state(), clusterService, client, indexNameExpressionResolver, xContentRegistry);
+        DetectorThreatIntelService detectorThreatIntelService = new DetectorThreatIntelService(threatIntelFeedDataService);
+        DatasourceDao datasourceDao = new DatasourceDao(client, clusterService);
+        this.client = client;
+        DatasourceUpdateService datasourceUpdateService = new DatasourceUpdateService(clusterService, datasourceDao, threatIntelFeedDataService);
+        ThreatIntelExecutor threatIntelExecutor = new ThreatIntelExecutor(threadPool);
+        ThreatIntelLockService threatIntelLockService = new ThreatIntelLockService(clusterService, client);
+
+        DatasourceRunner.getJobRunnerInstance().initialize(clusterService, datasourceUpdateService, datasourceDao, threatIntelExecutor, threatIntelLockService);
+
         return List.of(
                 detectorIndices, correlationIndices, correlationRuleIndices, ruleTopicIndices, customLogTypeIndices, ruleIndices,
-                mapperService, indexTemplateManager, builtinLogTypeLoader
+                mapperService, indexTemplateManager, builtinLogTypeLoader, threatIntelFeedDataService, detectorThreatIntelService,
+                datasourceUpdateService, datasourceDao, threatIntelExecutor, threatIntelLockService
         );
     }
@@ -193,7 +219,8 @@ public List<NamedXContentRegistry.Entry> getNamedXContent() {
                 Detector.XCONTENT_REGISTRY,
                 DetectorInput.XCONTENT_REGISTRY,
                 Rule.XCONTENT_REGISTRY,
-                CustomLogType.XCONTENT_REGISTRY
+                CustomLogType.XCONTENT_REGISTRY,
+                ThreatIntelFeedData.XCONTENT_REGISTRY
         );
     }
@@ -243,7 +270,12 @@ public List<Setting<?>> getSettings() {
                 SecurityAnalyticsSettings.IS_CORRELATION_INDEX_SETTING,
                 SecurityAnalyticsSettings.CORRELATION_TIME_WINDOW,
                 SecurityAnalyticsSettings.DEFAULT_MAPPING_SCHEMA,
-                SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE
+                SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE,
+                SecurityAnalyticsSettings.DATASOURCE_ENDPOINT,
+                SecurityAnalyticsSettings.DATASOURCE_UPDATE_INTERVAL,
+                SecurityAnalyticsSettings.BATCH_SIZE,
+                SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT,
+                SecurityAnalyticsSettings.CACHE_SIZE
         );
     }
@@ -292,5 +324,5 @@ public void onFailure(Exception e) {
                 log.warn("Failed to initialize LogType config index and builtin log types");
             }
         });
-    }
+    }
 }
\ No newline at end of file
diff --git a/src/main/java/org/opensearch/securityanalytics/action/GetDetectorResponse.java b/src/main/java/org/opensearch/securityanalytics/action/GetDetectorResponse.java
index 3e4fc68d1..0d700b88c 100644
--- a/src/main/java/org/opensearch/securityanalytics/action/GetDetectorResponse.java
+++ b/src/main/java/org/opensearch/securityanalytics/action/GetDetectorResponse.java
@@ -68,6 +68,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
                 .field(Detector.INPUTS_FIELD, detector.getInputs())
                 .field(Detector.LAST_UPDATE_TIME_FIELD, detector.getLastUpdateTime())
                 .field(Detector.ENABLED_TIME_FIELD, detector.getEnabledTime())
+                .field(Detector.THREAT_INTEL_ENABLED_FIELD, detector.getThreatIntelEnabled())
                 .endObject();
         return builder.endObject();
     }
diff --git a/src/main/java/org/opensearch/securityanalytics/action/IndexDetectorResponse.java b/src/main/java/org/opensearch/securityanalytics/action/IndexDetectorResponse.java
index 6a7c268c1..67fe36f0b 100644
--- a/src/main/java/org/opensearch/securityanalytics/action/IndexDetectorResponse.java
+++ b/src/main/java/org/opensearch/securityanalytics/action/IndexDetectorResponse.java
@@ -64,6 +64,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
                 .field(Detector.TRIGGERS_FIELD, detector.getTriggers())
                 .field(Detector.LAST_UPDATE_TIME_FIELD, detector.getLastUpdateTime())
                 .field(Detector.ENABLED_TIME_FIELD, detector.getEnabledTime())
+                .field(Detector.THREAT_INTEL_ENABLED_FIELD, detector.getThreatIntelEnabled())
                 .endObject();
         return builder.endObject();
     }
diff --git a/src/main/java/org/opensearch/securityanalytics/config/monitors/opensearch_security.policy b/src/main/java/org/opensearch/securityanalytics/config/monitors/opensearch_security.policy
new file mode 100644
index 000000000..c5af78398
--- /dev/null
+++ b/src/main/java/org/opensearch/securityanalytics/config/monitors/opensearch_security.policy
@@ -0,0 +1,3 @@
+grant {
+  permission java.net.SocketPermission "reputation.alienvault.com:443", "connect,resolve";
+};
\ No newline at end of file
diff --git a/src/main/java/org/opensearch/securityanalytics/model/Detector.java b/src/main/java/org/opensearch/securityanalytics/model/Detector.java
index ff832d1e7..65e4d18be 100644
--- a/src/main/java/org/opensearch/securityanalytics/model/Detector.java
+++ b/src/main/java/org/opensearch/securityanalytics/model/Detector.java
@@ -25,14 +25,11 @@
 import java.io.IOException;
 import java.time.Instant;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Locale;
 import java.util.Objects;
-import java.util.stream.Collectors;
-
 public class Detector implements Writeable, ToXContentObject {

     private static final Logger log = LogManager.getLogger(Detector.class);
@@ -51,6 +48,7 @@ public class Detector implements Writeable, ToXContentObject {
     public static final String TRIGGERS_FIELD = "triggers";
     public static final String LAST_UPDATE_TIME_FIELD = "last_update_time";
     public static final String ENABLED_TIME_FIELD = "enabled_time";
+    public static final String THREAT_INTEL_ENABLED_FIELD = "threat_intel_enabled";
     public static final String ALERTING_MONITOR_ID = "monitor_id";
     public static final String ALERTING_WORKFLOW_ID = "workflow_ids";
@@ -118,11 +116,14 @@ public class Detector implements Writeable, ToXContentObject {
     private final String type;

+    private final Boolean threatIntelEnabled;
+
     public Detector(String id, Long version, String name, Boolean enabled, Schedule schedule, Instant lastUpdateTime, Instant enabledTime,
                     String logType, User user, List<DetectorInput> inputs, List<DetectorTrigger> triggers, List<String> monitorIds,
                     String ruleIndex, String alertsIndex, String alertsHistoryIndex, String alertsHistoryIndexPattern,
-                    String findingsIndex, String findingsIndexPattern, Map<String, String> rulePerMonitor, List<String> workflowIds) {
+                    String findingsIndex, String findingsIndexPattern, Map<String, String> rulePerMonitor,
+                    List<String> workflowIds, Boolean threatIntelEnabled) {
         this.type = DETECTOR_TYPE;

         this.id = id != null ? id : NO_ID;
@@ -145,6 +146,7 @@ public Detector(String id, Long version, String name, Boolean enabled, Schedule
         this.ruleIdMonitorIdMap = rulePerMonitor;
         this.logType = logType;
         this.workflowIds = workflowIds != null ?
workflowIds : null; + this.threatIntelEnabled = threatIntelEnabled != null && threatIntelEnabled; if (enabled) { Objects.requireNonNull(enabledTime); @@ -172,7 +174,8 @@ public Detector(StreamInput sin) throws IOException { sin.readString(), sin.readString(), sin.readMap(StreamInput::readString, StreamInput::readString), - sin.readStringList() + sin.readStringList(), + sin.readOptionalBoolean() ); } @@ -211,6 +214,7 @@ public void writeTo(StreamOutput out) throws IOException { if (workflowIds != null) { out.writeStringCollection(workflowIds); } + out.writeOptionalBoolean(threatIntelEnabled); } public XContentBuilder toXContentWithUser(XContentBuilder builder, Params params) throws IOException { @@ -239,6 +243,7 @@ private XContentBuilder createXContentBuilder(XContentBuilder builder, ToXConten } } + builder.field(THREAT_INTEL_ENABLED_FIELD, threatIntelEnabled); builder.field(ENABLED_FIELD, enabled); if (enabledTime == null) { @@ -280,7 +285,6 @@ private XContentBuilder createXContentBuilder(XContentBuilder builder, ToXConten builder.field(FINDINGS_INDEX, findingsIndex); builder.field(FINDINGS_INDEX_PATTERN, findingsIndexPattern); - if (params.paramAsBoolean("with_type", false)) { builder.endObject(); } @@ -327,6 +331,7 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws String alertsHistoryIndexPattern = null; String findingsIndex = null; String findingsIndexPattern = null; + Boolean enableThreatIntel = false; XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp); while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { @@ -350,6 +355,9 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws case ENABLED_FIELD: enabled = xcp.booleanValue(); break; + case THREAT_INTEL_ENABLED_FIELD: + enableThreatIntel = xcp.booleanValue(); + break; case SCHEDULE_FIELD: schedule = Schedule.parse(xcp); break; @@ -459,7 +467,8 @@ public static Detector parse(XContentParser xcp, String id, Long version) throws findingsIndex, findingsIndexPattern, rulePerMonitor, - workflowIds + workflowIds, + enableThreatIntel ); } @@ -612,6 +621,10 @@ public boolean isWorkflowSupported() { return workflowIds != null && !workflowIds.isEmpty(); } + public Boolean getThreatIntelEnabled() { + return threatIntelEnabled; + } + @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/src/main/java/org/opensearch/securityanalytics/model/ThreatIntelFeedData.java b/src/main/java/org/opensearch/securityanalytics/model/ThreatIntelFeedData.java new file mode 100644 index 000000000..1870f383a --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/model/ThreatIntelFeedData.java @@ -0,0 +1,159 @@ +package org.opensearch.securityanalytics.model; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; + +import java.io.IOException; +import java.time.Instant; +import java.util.Locale; +import java.util.Objects; + +/** + * Model for 
threat intel feed data stored in system index. + */ +public class ThreatIntelFeedData implements Writeable, ToXContentObject { + private static final Logger log = LogManager.getLogger(ThreatIntelFeedData.class); + private static final String FEED_TYPE = "feed"; + private static final String TYPE_FIELD = "type"; + private static final String IOC_TYPE_FIELD = "ioc_type"; + private static final String IOC_VALUE_FIELD = "ioc_value"; + private static final String FEED_ID_FIELD = "feed_id"; + private static final String TIMESTAMP_FIELD = "timestamp"; + + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + ThreatIntelFeedData.class, + new ParseField(FEED_TYPE), + xcp -> parse(xcp, null, null) + ); + + private final String iocType; + private final String iocValue; + private final String feedId; + private final Instant timestamp; + private final String type; + + public ThreatIntelFeedData(String iocType, String iocValue, String feedId, Instant timestamp) { + this.type = FEED_TYPE; + + this.iocType = iocType; + this.iocValue = iocValue; + this.feedId = feedId; + this.timestamp = timestamp; + } + + public static ThreatIntelFeedData parse(XContentParser xcp, String id, Long version) throws IOException { + String iocType = null; + String iocValue = null; + String feedId = null; + Instant timestamp = null; + + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = xcp.currentName(); + xcp.nextToken(); + + switch (fieldName) { + case IOC_TYPE_FIELD: + iocType = xcp.text(); + break; + case IOC_VALUE_FIELD: + iocValue = xcp.text(); + break; + case FEED_ID_FIELD: + feedId = xcp.text(); + break; + case TIMESTAMP_FIELD: + if (xcp.currentToken() == XContentParser.Token.VALUE_NULL) { + timestamp = null; + } else if (xcp.currentToken().isValue()) { + timestamp = Instant.ofEpochMilli(xcp.longValue()); + } else { + XContentParserUtils.throwUnknownToken(xcp.currentToken(), xcp.getTokenLocation()); + timestamp = null; + } + break; + default: + xcp.skipChildren(); + } + } + return new ThreatIntelFeedData(iocType, iocValue, feedId, timestamp); + } + + public String getIocType() { + return iocType; + } + + public String getIocValue() { + return iocValue; + } + + public String getFeedId() { + return feedId; + } + + public Instant getTimestamp() { + return timestamp; + } + + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(iocType); + out.writeString(iocValue); + out.writeString(feedId); + out.writeInstant(timestamp); + } + + public ThreatIntelFeedData(StreamInput sin) throws IOException { + this( + sin.readString(), + sin.readString(), + sin.readString(), + sin.readInstant() + ); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return createXContentBuilder(builder, params); + } + + private XContentBuilder createXContentBuilder(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (params.paramAsBoolean("with_type", false)) { + builder.startObject(type); + } + builder.field(TYPE_FIELD, type); + builder + .field(IOC_TYPE_FIELD, iocType) + .field(IOC_VALUE_FIELD, iocValue) + .field(FEED_ID_FIELD, feedId) + .timeField(TIMESTAMP_FIELD, String.format(Locale.getDefault(), "%s_in_millis", TIMESTAMP_FIELD), timestamp.toEpochMilli()); + + return builder.endObject(); + } + + + @Override + public 
boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        ThreatIntelFeedData tif = (ThreatIntelFeedData) o;
+        return Objects.equals(iocType, tif.iocType) && Objects.equals(iocValue, tif.iocValue) && Objects.equals(feedId, tif.feedId);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(iocType, iocValue, feedId);
+    }
+}
diff --git a/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java b/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java
index 4085d7ae2..0595375a0 100644
--- a/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java
+++ b/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java
@@ -4,10 +4,14 @@
  */
 package org.opensearch.securityanalytics.settings;

+import java.net.MalformedURLException;
+import java.net.URISyntaxException;
+import java.net.URL;
+import java.util.List;
 import java.util.concurrent.TimeUnit;
 import org.opensearch.common.settings.Setting;
 import org.opensearch.common.unit.TimeValue;
-import org.opensearch.securityanalytics.model.FieldMappingDoc;
+import org.opensearch.jobscheduler.repackage.com.cronutils.utils.VisibleForTesting;

 public class SecurityAnalyticsSettings {
     public static final String CORRELATION_INDEX = "index.correlation";
@@ -117,4 +121,84 @@ public class SecurityAnalyticsSettings {
             "ecs",
             Setting.Property.NodeScope, Setting.Property.Dynamic
     );
+
+    // threat intel settings
+    /**
+     * Default endpoint to be used in threat intel feed datasource creation API
+     */
+    public static final Setting<String> DATASOURCE_ENDPOINT = Setting.simpleString(
+            "plugins.security_analytics.threatintel.datasource.endpoint",
+            "https://feodotracker.abuse.ch/downloads/ipblocklist_aggressive.csv", //TODO: fix this endpoint
+            new DatasourceEndpointValidator(),
+            Setting.Property.NodeScope,
+            Setting.Property.Dynamic
+    );
+
+    /**
+     * Default update interval to be used in threat intel datasource creation API
+     */
+    public static final Setting<Long> DATASOURCE_UPDATE_INTERVAL = Setting.longSetting(
+            "plugins.security_analytics.threatintel.datasource.update_interval_in_days",
+            3L,
+            1L,
+            Setting.Property.NodeScope,
+            Setting.Property.Dynamic
+    );
+
+    /**
+     * Bulk size for indexing threat intel feed data
+     */
+    public static final Setting<Integer> BATCH_SIZE = Setting.intSetting(
+            "plugins.security_analytics.threatintel.datasource.batch_size",
+            10000,
+            1,
+            Setting.Property.NodeScope,
+            Setting.Property.Dynamic
+    );
+
+    /**
+     * Timeout value for threat intel processor
+     */
+    public static final Setting<TimeValue> THREAT_INTEL_TIMEOUT = Setting.timeSetting(
+            "plugins.security_analytics.threat_intel_timeout",
+            TimeValue.timeValueSeconds(30),
+            TimeValue.timeValueSeconds(1),
+            Setting.Property.NodeScope,
+            Setting.Property.Dynamic
+    );
+
+    /**
+     * Max size for the threat intel feed data cache
+     */
+    public static final Setting<Long> CACHE_SIZE = Setting.longSetting(
+            "plugins.security_analytics.threatintel.processor.cache_size",
+            1000,
+            0,
+            Setting.Property.NodeScope,
+            Setting.Property.Dynamic
+    );
+
+    /**
+     * Return all settings of the threat intel feature
+     * @return a list of all settings for the threat intel feature
+     */
+    public static final List<Setting<?>> settings() {
+        return List.of(DATASOURCE_ENDPOINT, DATASOURCE_UPDATE_INTERVAL, BATCH_SIZE, THREAT_INTEL_TIMEOUT, CACHE_SIZE);
+    }
+
+    /**
+     * Visible for testing
+     */
+    @VisibleForTesting
+    protected static class DatasourceEndpointValidator implements Setting.Validator<String> {
+        @Override
+        public void validate(final String value) {
+            try {
+                new URL(value).toURI();
+            } catch (MalformedURLException | URISyntaxException e) {
+                throw new IllegalArgumentException("Invalid URL format provided");
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/DetectorThreatIntelService.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/DetectorThreatIntelService.java
new file mode 100644
index 000000000..0e940988e
--- /dev/null
+++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/DetectorThreatIntelService.java
@@ -0,0 +1,61 @@
+package org.opensearch.securityanalytics.threatIntel;
+
+import org.opensearch.commons.alerting.model.DocLevelQuery;
+import org.opensearch.core.rest.RestStatus;
+import org.opensearch.securityanalytics.model.Detector;
+import org.opensearch.securityanalytics.model.ThreatIntelFeedData;
+import org.opensearch.securityanalytics.util.SecurityAnalyticsException;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+
+public class DetectorThreatIntelService {
+
+    private final ThreatIntelFeedDataService threatIntelFeedDataService;
+
+    public DetectorThreatIntelService(ThreatIntelFeedDataService threatIntelFeedDataService) {
+        this.threatIntelFeedDataService = threatIntelFeedDataService;
+    }
+
+    /** Convert the feed data IOCs into query string query format to create doc level queries. */
+    public DocLevelQuery createDocLevelQueryFromThreatIntelList(
+            List<ThreatIntelFeedData> tifdList, String docLevelQueryId
+    ) {
+        Set<String> iocs = tifdList.stream().map(ThreatIntelFeedData::getIocValue).collect(Collectors.toSet());
+        String query = buildQueryStringQueryWithIocList(iocs);
+        return new DocLevelQuery(
+                docLevelQueryId, tifdList.get(0).getFeedId(), query,
+                Collections.singletonList("threat_intel")
+        );
+    }
+
+    // e.g. the IOC set {"1.2.3.4", "5.6.7.8"} yields the query string: (1.2.3.4) (5.6.7.8)
+    private String buildQueryStringQueryWithIocList(Set<String> iocs) {
+        StringBuilder sb = new StringBuilder();
+
+        for (String ioc : iocs) {
+            if (sb.length() != 0) {
+                sb.append(" ");
+            }
+            sb.append("(");
+            sb.append(ioc);
+            sb.append(")");
+        }
+        return sb.toString();
+    }
+
+    public DocLevelQuery createDocLevelQueryFromThreatIntel(Detector detector) {
+        // for testing validation only.
+        if (detector.getThreatIntelEnabled() == false) {
+            throw new SecurityAnalyticsException(
+                    "Trying to create threat intel feed queries when the flag to use threat intel is disabled.",
+                    RestStatus.FORBIDDEN, new IllegalArgumentException());
+
+        }
+        // TODO: plugin logic to run job for populating threat intel feed data
+        /*threatIntelFeedDataService.getThreatIntelFeedData("ip_address", );*/
+        return null;
+    }
+}
diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedDataService.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedDataService.java
new file mode 100644
index 000000000..dac0b1b70
--- /dev/null
+++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedDataService.java
@@ -0,0 +1,294 @@
+package org.opensearch.securityanalytics.threatIntel;
+
+import org.apache.commons.csv.CSVRecord;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.opensearch.OpenSearchException;
+import org.opensearch.action.admin.indices.create.CreateIndexRequest;
+import org.opensearch.action.bulk.BulkRequest;
+import org.opensearch.action.bulk.BulkResponse;
+import org.opensearch.action.index.IndexRequest;
+import org.opensearch.action.search.SearchRequest;
+import org.opensearch.action.search.SearchResponse;
+import org.opensearch.action.support.IndicesOptions;
+import org.opensearch.action.support.master.AcknowledgedResponse;
+import org.opensearch.client.Client;
+import org.opensearch.client.Requests;
+import org.opensearch.cluster.ClusterState;
+import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.ClusterSettings;
+import org.opensearch.common.unit.TimeValue;
+import org.opensearch.common.xcontent.LoggingDeprecationHandler;
+import org.opensearch.common.xcontent.XContentFactory;
+import org.opensearch.common.xcontent.XContentType;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.rest.RestStatus;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.ToXContent;
+import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.search.builder.SearchSourceBuilder;
+import org.opensearch.securityanalytics.model.ThreatIntelFeedData;
+import org.opensearch.securityanalytics.threatIntel.common.DatasourceManifest;
+import org.opensearch.securityanalytics.threatIntel.common.StashedThreadContext;
+import org.opensearch.securityanalytics.threatIntel.dao.DatasourceDao;
+import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings;
+import org.opensearch.securityanalytics.util.IndexUtils;
+import org.opensearch.securityanalytics.util.SecurityAnalyticsException;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.time.Instant;
+import java.util.*;
+import java.util.stream.Collectors;
+
+import static org.opensearch.securityanalytics.threatIntel.jobscheduler.Datasource.THREAT_INTEL_DATA_INDEX_NAME_PREFIX;
+
+/**
+ * Service to handle CRUD operations on Threat Intel Feed Data
+ */
+public class ThreatIntelFeedDataService {
+    private static final Logger log = LogManager.getLogger(ThreatIntelFeedDataService.class);
+
+    private final ClusterState state;
+    private final Client client;
+    private final IndexNameExpressionResolver indexNameExpressionResolver;
+    private final NamedXContentRegistry xContentRegistry;
+
+    private static final Map<String, Object> INDEX_SETTING_TO_CREATE = Map.of(
+            "index.number_of_shards",
+            1,
+            "index.number_of_replicas",
+            0,
+            "index.refresh_interval",
+            -1,
+            "index.hidden",
+            true
+    );
+    private static final Map<String, Object> INDEX_SETTING_TO_FREEZE = Map.of(
+            "index.auto_expand_replicas",
+            "0-all",
+            "index.blocks.write",
+            true
+    );
+    private final ClusterService clusterService;
+    private final ClusterSettings clusterSettings;
+
+    public ThreatIntelFeedDataService(
+            ClusterState state,
+            ClusterService clusterService,
+            Client client,
+            IndexNameExpressionResolver indexNameExpressionResolver,
+            NamedXContentRegistry xContentRegistry) {
+        this.state = state;
+        this.client = client;
+        this.indexNameExpressionResolver = indexNameExpressionResolver;
+        this.xContentRegistry = xContentRegistry;
+        this.clusterService = clusterService;
+        this.clusterSettings = clusterService.getClusterSettings();
+    }
+
+    public void getThreatIntelFeedData(
+            String iocType,
+            ActionListener<List<ThreatIntelFeedData>> listener
+    ) {
+        String tifdIndex = IndexUtils.getNewIndexByCreationDate(
+                this.state,
+                this.indexNameExpressionResolver,
+                ".opensearch-sap-threatintel*" //name?
+        );
+        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
+        sourceBuilder.query(QueryBuilders.boolQuery().filter(QueryBuilders.termQuery("ioc_type", iocType)));
+        sourceBuilder.size(9999); //TODO: convert to scroll
+        SearchRequest searchRequest = new SearchRequest(tifdIndex);
+        searchRequest.source(sourceBuilder);
+        client.search(searchRequest, ActionListener.wrap(r -> listener.onResponse(getTifdList(r)), e -> {
+            log.error(String.format(Locale.ROOT,
+                    "Failed to fetch threat intel feed data from system index %s", tifdIndex), e);
+            listener.onFailure(e);
+        }));
+    }
+
+    private List<ThreatIntelFeedData> getTifdList(SearchResponse searchResponse) {
+        List<ThreatIntelFeedData> list = new ArrayList<>();
+        if (searchResponse.getHits().getHits().length != 0) {
+            Arrays.stream(searchResponse.getHits().getHits()).forEach(hit -> {
+                try {
+                    XContentParser xcp = XContentType.JSON.xContent().createParser(
+                            xContentRegistry,
+                            LoggingDeprecationHandler.INSTANCE, hit.getSourceAsString()
+                    );
+                    list.add(ThreatIntelFeedData.parse(xcp, hit.getId(), hit.getVersion()));
+                } catch (Exception e) {
+                    log.error(() -> new ParameterizedMessage(
+                                    "Failed to parse Threat intel feed data doc from hit {}", hit),
+                            e
+                    );
+                }
+
+            });
+        }
+        return list;
+    }
+
+    /**
+     * Create an index for a threat intel feed
+     *
+     * Index settings start with a single shard, zero replicas, no refresh interval, and hidden.
+     * Once the threat intel feed is indexed, do a refresh and force merge.
+     * Then, change the index settings to expand replicas to all nodes and block writes.
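+     * <p>Illustrative call sequence for ingesting a feed (a sketch, not a literal excerpt;
+     * the variable names are hypothetical):
+     * <pre>{@code
+     * CSVParser reader = ThreatIntelFeedParser.getThreatIntelFeedReaderCSV(manifest);
+     * feedDataService.createIndexIfNotExists(indexName);
+     * feedDataService.saveThreatIntelFeedDataCSV(indexName, fields, reader.iterator(), renewLock, manifest);
+     * }</pre>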
+     *
+     * @param indexName index name
+     */
+    public void createIndexIfNotExists(final String indexName) {
+        if (clusterService.state().metadata().hasIndex(indexName)) {
+            return;
+        }
+        final CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName).settings(INDEX_SETTING_TO_CREATE)
+                .mapping(getIndexMapping());
+        StashedThreadContext.run(
+                client,
+                () -> client.admin().indices().create(createIndexRequest).actionGet(clusterSettings.get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT))
+        );
+    }
+
+    private String getIndexMapping() {
+        try {
+            try (InputStream is = DatasourceDao.class.getResourceAsStream("/mappings/threat_intel_feed_mapping.json")) { // TODO: check Datasource dao and this mapping
+                try (BufferedReader reader = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
+                    return reader.lines().map(String::trim).collect(Collectors.joining());
+                }
+            }
+        } catch (IOException e) {
+            log.error("Runtime exception when getting the threat intel index mapping", e);
+            throw new SecurityAnalyticsException("Runtime exception when getting the threat intel index mapping", RestStatus.INTERNAL_SERVER_ERROR, e);
+        }
+    }
+
+    /**
+     * Puts threat intel feed data from a CSVRecord iterator into a given index in bulk
+     *
+     * @param indexName Index name to put the TIF data into
+     * @param fields    Field names matching the data in each CSVRecord, in order
+     * @param iterator  TIF data to insert
+     * @param renewLock Runnable to renew the lock
+     * @param manifest  Datasource manifest describing the feed
+     */
+    public void saveThreatIntelFeedDataCSV(
+            final String indexName,
+            final String[] fields,
+            final Iterator<CSVRecord> iterator,
+            final Runnable renewLock,
+            final DatasourceManifest manifest
+    ) throws IOException {
+        if (indexName == null || fields == null || iterator == null || renewLock == null) {
+            throw new IllegalArgumentException("Parameters cannot be null, failed to save threat intel feed data");
+        }
+
+        TimeValue timeout = clusterSettings.get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT);
+        Integer batchSize = clusterSettings.get(SecurityAnalyticsSettings.BATCH_SIZE);
+        final BulkRequest bulkRequest = new BulkRequest();
+        // pre-allocate a pool of index requests that is recycled after every bulk flush
+        Queue<IndexRequest> requests = new LinkedList<>();
+        for (int i = 0; i < batchSize; i++) {
+            requests.add(Requests.indexRequest(indexName));
+        }
+
+        while (iterator.hasNext()) {
+            CSVRecord record = iterator.next();
+            String iocType = "";
+            if ("ip".equals(manifest.getContainedIocs().get(0))) { //TODO: dynamically get the type
+                iocType = "ip";
+            }
+            int colNum = Integer.parseInt(manifest.getIocCol());
+            String iocValue = record.values()[colNum];
+            String feedId = manifest.getFeedId();
+            Instant timestamp = Instant.now();
+
+            ThreatIntelFeedData threatIntelFeedData = new ThreatIntelFeedData(iocType, iocValue, feedId, timestamp);
+            XContentBuilder tifData = threatIntelFeedData.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
+            IndexRequest indexRequest = requests.poll();
+            indexRequest.source(tifData);
+            indexRequest.id(record.get(0));
+            bulkRequest.add(indexRequest);
+            if (iterator.hasNext() == false || bulkRequest.requests().size() == batchSize) {
+                BulkResponse response = StashedThreadContext.run(client, () -> client.bulk(bulkRequest).actionGet(timeout));
+                if (response.hasFailures()) {
+                    throw new OpenSearchException(
+                            "error occurred while ingesting threat intel feed data in {} with an error {}",
+                            indexName,
+                            response.buildFailureMessage()
+                    );
+                }
+                requests.addAll(bulkRequest.requests());
+                bulkRequest.requests().clear();
+            }
+            renewLock.run();
+        }
+        freezeIndex(indexName);
+    }
+
+    private void freezeIndex(final String indexName) {
+        TimeValue timeout = clusterSettings.get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT);
+        StashedThreadContext.run(client, () -> {
+            client.admin().indices().prepareForceMerge(indexName).setMaxNumSegments(1).execute().actionGet(timeout);
+            client.admin().indices().prepareRefresh(indexName).execute().actionGet(timeout);
+            client.admin()
+                    .indices()
+                    .prepareUpdateSettings(indexName)
+                    .setSettings(INDEX_SETTING_TO_FREEZE)
+                    .execute()
+                    .actionGet(timeout);
+        });
+    }
+
+    public void deleteThreatIntelDataIndex(final String index) {
+        deleteThreatIntelDataIndex(Arrays.asList(index));
+    }
+
+    public void deleteThreatIntelDataIndex(final List<String> indices) {
+        if (indices == null || indices.isEmpty()) {
+            return;
+        }
+
+        Optional<String> invalidIndex = indices.stream()
+                .filter(index -> index.startsWith(THREAT_INTEL_DATA_INDEX_NAME_PREFIX) == false)
+                .findAny();
+        if (invalidIndex.isPresent()) {
+            throw new OpenSearchException(
+                    "the index [{}] is not a threat intel data index, which should start with {}",
+                    invalidIndex.get(),
+                    THREAT_INTEL_DATA_INDEX_NAME_PREFIX
+            );
+        }
+
+        AcknowledgedResponse response = StashedThreadContext.run(
+                client,
+                () -> client.admin()
+                        .indices()
+                        .prepareDelete(indices.toArray(new String[0]))
+                        .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN_CLOSED_HIDDEN)
+                        .execute()
+                        .actionGet(clusterSettings.get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT))
+        );
+
+        if (response.isAcknowledged() == false) {
+            throw new OpenSearchException("failed to delete data[{}] in datasource", String.join(",", indices));
+        }
+    }
+
+}
diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedParser.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedParser.java
new file mode 100644
index 000000000..07e93eea3
--- /dev/null
+++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelFeedParser.java
@@ -0,0 +1,76 @@
+package org.opensearch.securityanalytics.threatIntel;
+
+import org.apache.commons.csv.CSVFormat;
+import org.apache.commons.csv.CSVParser;
+import org.apache.commons.csv.CSVRecord;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.OpenSearchException;
+import org.opensearch.SpecialPermission;
+import org.opensearch.common.SuppressForbidden;
+import org.opensearch.securityanalytics.threatIntel.common.Constants;
+import org.opensearch.securityanalytics.threatIntel.common.DatasourceManifest;
+
+import java.io.*;
+import java.net.URL;
+import java.net.URLConnection;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
+
+// Parser helper class
+public class ThreatIntelFeedParser {
+    private static final Logger log = LogManager.getLogger(ThreatIntelFeedParser.class);
+
+    /**
+     * Create a CSVParser of a threat intel feed
+     *
+     * @param manifest Datasource manifest
+     * @return parser for threat intel feed
+     */
+    @SuppressForbidden(reason = "Need to connect to http endpoint to read threat intel feed database file")
+    public static CSVParser getThreatIntelFeedReaderCSV(final DatasourceManifest manifest) {
+        SpecialPermission.check();
+        return AccessController.doPrivileged((PrivilegedAction<CSVParser>) () -> {
+            try {
+                URL url = new URL(manifest.getUrl());
+                URLConnection connection = url.openConnection();
+                connection.addRequestProperty(Constants.USER_AGENT_KEY, Constants.USER_AGENT_VALUE);
+                return new CSVParser(new BufferedReader(new InputStreamReader(connection.getInputStream())), CSVFormat.RFC4180);
+            } catch (IOException e) {
+                log.error("Exception: failed to read threat intel feed data from {}", manifest.getUrl(), e);
+                throw new OpenSearchException("failed to read threat intel feed data from {}", manifest.getUrl(), e);
+            }
+        });
+    }
+
+    /**
+     * Validate header
+     *
+     * 1. header should not be null
+     * 2. the number of values in the header should be more than one
+     *
+     * @param header the header
+     * @return CSVRecord the input header
+     */
+    public static CSVRecord validateHeader(CSVRecord header) {
+        if (header == null) {
+            throw new OpenSearchException("threat intel feed database is empty");
+        }
+        if (header.values().length < 2) {
+            throw new OpenSearchException("threat intel feed database should have at least two fields");
+        }
+        return header;
+    }
+}
diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/common/Constants.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/common/Constants.java
new file mode 100644
index 000000000..af31e7897
--- /dev/null
+++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/common/Constants.java
@@ -0,0 +1,9 @@
+package org.opensearch.securityanalytics.threatIntel.common;
+
+import org.opensearch.Version;
+
+import java.util.Locale;
+public class Constants {
+    public static final String USER_AGENT_KEY = "User-Agent";
+    public static final String USER_AGENT_VALUE = String.format(Locale.ROOT, "OpenSearch/%s vanilla", Version.CURRENT.toString());
+}
diff --git a/src/main/java/org/opensearch/securityanalytics/threatintel/common/DatasourceManifest.java b/src/main/java/org/opensearch/securityanalytics/threatintel/common/DatasourceManifest.java
new file mode 100644
index 000000000..5835385bf
--- /dev/null
+++ b/src/main/java/org/opensearch/securityanalytics/threatintel/common/DatasourceManifest.java
@@ -0,0 +1,205 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.securityanalytics.threatIntel.common;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.net.URL;
+import java.net.URLConnection;
+import java.nio.CharBuffer;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
+import java.util.List;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.SpecialPermission;
+import org.opensearch.common.SuppressForbidden;
+import org.opensearch.common.xcontent.json.JsonXContent;
+import org.opensearch.core.ParseField;
+import org.opensearch.core.rest.RestStatus;
+import org.opensearch.core.xcontent.ConstructingObjectParser;
+import org.opensearch.core.xcontent.DeprecationHandler;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.securityanalytics.util.SecurityAnalyticsException;
+
+/**
+ * Threat intel datasource manifest file object
+ *
+ * The manifest file is stored at an external endpoint. OpenSearch reads the file and stores its values in this object.
+ */
+public class DatasourceManifest {
+    private static final Logger log = LogManager.getLogger(DatasourceManifest.class);
+
+    private static final ParseField FEED_ID = new ParseField("id");
+    private static final ParseField URL_FIELD = new ParseField("url");
+    private static final ParseField NAME = new ParseField("name");
+    private static final ParseField ORGANIZATION = new ParseField("organization");
+    private static final ParseField DESCRIPTION = new ParseField("description");
+    private static final ParseField FEED_TYPE = new ParseField("feed_type");
+    private static final ParseField CONTAINED_IOCS = new ParseField("contained_iocs");
+    private static final ParseField IOC_COL = new ParseField("ioc_col");
+
+    /**
+     * @param feedId ID of the threat intel feed data
+     * @return ID of the threat intel feed data
+     */
+    private String feedId;
+
+    /**
+     * @param url URL of the threat intel feed data
+     * @return URL of the threat intel feed data
+     */
+    private String url;
+
+    /**
+     * @param name Name of the threat intel feed
+     * @return Name of the threat intel feed
+     */
+    private String name;
+
+    /**
+     * @param organization A threat intel feed organization name
+     * @return A threat intel feed organization name
+     */
+    private String organization;
+
+    /**
+     * @param description A description of the database
+     * @return A description of a database
+     */
+    private String description;
+
+    /**
+     * @param feedType The type of the data feed (csv, json...)
+     * @return The type of the data feed (csv, json...)
+     */
+    private String feedType;
+
+    /**
+     * @param iocCol the column of the ioc data if feedType is csv
+     * @return the column of the ioc data if feedType is csv
+     */
+    private String iocCol;
+
+    /**
+     * @param containedIocs list of ioc types contained in feed
+     * @return list of ioc types contained in feed
+     */
+    private List<String> containedIocs;
+
+
+    public String getUrl() {
+        return url;
+    }
+    public String getName() {
+        return name;
+    }
+    public String getOrganization() {
+        return organization;
+    }
+    public String getDescription() {
+        return description;
+    }
+    public String getFeedId() {
+        return feedId;
+    }
+    public String getFeedType() {
+        return feedType;
+    }
+    public String getIocCol() {
+        return iocCol;
+    }
+    public List<String> getContainedIocs() {
+        return containedIocs;
+    }
+
+    public DatasourceManifest(final String feedId, final String url, final String name, final String organization, final String description, final String feedType, final List<String> containedIocs, final String iocCol) {
+        this.feedId = feedId;
+        this.url = url;
+        this.name = name;
+        this.organization = organization;
+        this.description = description;
+        this.feedType = feedType;
+        this.containedIocs = containedIocs;
+        this.iocCol = iocCol;
+    }
+
+    /**
+     * Datasource manifest parser
+     */
+    public static final ConstructingObjectParser<DatasourceManifest, Void> PARSER = new ConstructingObjectParser<>(
+            "datasource_manifest",
+            true,
+            args -> {
+                String feedId = (String) args[0];
+                String url = (String) args[1];
+                String name = (String) args[2];
+                String organization = (String) args[3];
+                String description = (String) args[4];
+                String feedType = (String) args[5];
+                List<String> containedIocs = (List<String>) args[6];
+                String iocCol = (String) args[7];
+                return new DatasourceManifest(feedId, url, name, organization, description, feedType, containedIocs, iocCol);
+            }
+    );
+    static {
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), FEED_ID);
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), URL_FIELD);
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), NAME);
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), ORGANIZATION);
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), DESCRIPTION);
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), FEED_TYPE);
+        PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), CONTAINED_IOCS);
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), IOC_COL);
+    }
+
+    /**
+     * Datasource manifest builder
+     */
+    public static class Builder { //TODO: builder?
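+        // Illustrative manifest document the PARSER above accepts; every value below is a
+        // hypothetical example shown only to document the expected shape of the file:
+        // {
+        //   "id": "sample_feed",
+        //   "url": "https://example.com/ipblocklist.csv",
+        //   "name": "Sample IP blocklist",
+        //   "organization": "Example Org",
+        //   "description": "Sample feed of malicious IP addresses",
+        //   "feed_type": "csv",
+        //   "contained_iocs": ["ip"],
+        //   "ioc_col": "0"
+        // }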
+        private static final int MANIFEST_FILE_MAX_BYTES = 1024 * 8;
+
+        /**
+         * Build DatasourceManifest from a given url
+         *
+         * @param url url to download a manifest file from
+         * @return DatasourceManifest representing the manifest file
+         */
+        @SuppressForbidden(reason = "Need to connect to http endpoint to read manifest file") // change permissions
+        public static DatasourceManifest build(final URL url) {
+            SpecialPermission.check();
+            return AccessController.doPrivileged((PrivilegedAction<DatasourceManifest>) () -> {
+                try {
+                    URLConnection connection = url.openConnection();
+                    return internalBuild(connection);
+                } catch (IOException e) {
+                    log.error("Runtime exception connecting to the manifest file", e);
+                    throw new SecurityAnalyticsException("Runtime exception", RestStatus.INTERNAL_SERVER_ERROR, e); //TODO
+                }
+            });
+        }
+
+        @SuppressForbidden(reason = "Need to connect to http endpoint to read manifest file")
+        protected static DatasourceManifest internalBuild(final URLConnection connection) throws IOException {
+            connection.addRequestProperty(Constants.USER_AGENT_KEY, Constants.USER_AGENT_VALUE);
+            InputStreamReader inputStreamReader = new InputStreamReader(connection.getInputStream());
+            try (BufferedReader reader = new BufferedReader(inputStreamReader)) {
+                CharBuffer charBuffer = CharBuffer.allocate(MANIFEST_FILE_MAX_BYTES);
+                // a single read() may return before the buffer is full; keep reading until EOF or capacity
+                while (charBuffer.hasRemaining() && reader.read(charBuffer) != -1) {
+                }
+                charBuffer.flip();
+                XContentParser parser = JsonXContent.jsonXContent.createParser(
+                        NamedXContentRegistry.EMPTY,
+                        DeprecationHandler.IGNORE_DEPRECATIONS,
+                        charBuffer.toString()
+                );
+                return PARSER.parse(parser, null);
+            }
+        }
+    }
+}
diff --git a/src/main/java/org/opensearch/securityanalytics/threatintel/common/DatasourceState.java b/src/main/java/org/opensearch/securityanalytics/threatintel/common/DatasourceState.java
new file mode 100644
index 000000000..a516b1d34
--- /dev/null
+++ b/src/main/java/org/opensearch/securityanalytics/threatintel/common/DatasourceState.java
@@ -0,0 +1,37 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.securityanalytics.threatIntel.common;
+
+/**
+ * Threat intel datasource state
+ *
+ * When a data source is created, it starts in the CREATING state. Once the first threat intel feed is generated, the state changes to AVAILABLE.
+ * Only if the first threat intel feed generation fails does the state change to CREATE_FAILED.
+ * Subsequent threat intel feed failures won't change the data source state from AVAILABLE to CREATE_FAILED.
+ * When a delete request is received, the data source state changes to DELETING.
+ *
+ * The state changes from left to right over the lifecycle of a datasource:
+ * (CREATING) to (CREATE_FAILED or AVAILABLE) to (DELETING)
+ *
+ */
+public enum DatasourceState {
+    /**
+     * Data source is being created
+     */
+    CREATING,
+    /**
+     * Data source is ready to be used
+     */
+    AVAILABLE,
+    /**
+     * Data source creation failed
+     */
+    CREATE_FAILED,
+    /**
+     * Data source is being deleted
+     */
+    DELETING
+}
diff --git a/src/main/java/org/opensearch/securityanalytics/threatintel/common/StashedThreadContext.java b/src/main/java/org/opensearch/securityanalytics/threatintel/common/StashedThreadContext.java
new file mode 100644
index 000000000..32f4e6d40
--- /dev/null
+++ b/src/main/java/org/opensearch/securityanalytics/threatintel/common/StashedThreadContext.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.securityanalytics.threatIntel.common;
+
+import java.util.function.Supplier;
+
+import org.opensearch.client.Client;
+import org.opensearch.common.util.concurrent.ThreadContext;
+
+/**
+ * Helper class to run code with a stashed thread context
+ *
+ * Code needs to be run with a stashed thread context if it interacts with a system index
+ * when the security plugin is enabled.
+ */
+public class StashedThreadContext {
+    /**
+     * Set the thread context to default; this is needed to allow actions on the system index
+     * when the security plugin is enabled
+     * @param function runnable that needs to be executed after the thread context has been stashed, accepts and returns nothing
+     */
+    public static void run(final Client client, final Runnable function) {
+        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
+            function.run();
+        }
+    }
+
+    /**
+     * Set the thread context to default; this is needed to allow actions on the system index
+     * when the security plugin is enabled
+     * @param function supplier function that needs to be executed after the thread context has been stashed, returns an object
+     */
+    public static <T> T run(final Client client, final Supplier<T> function) {
+        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
+            return function.get();
+        }
+    }
+}
+
diff --git a/src/main/java/org/opensearch/securityanalytics/threatintel/common/ThreatIntelExecutor.java b/src/main/java/org/opensearch/securityanalytics/threatintel/common/ThreatIntelExecutor.java
new file mode 100644
index 000000000..09c916389
--- /dev/null
+++ b/src/main/java/org/opensearch/securityanalytics/threatintel/common/ThreatIntelExecutor.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.securityanalytics.threatIntel.common;
+
+import java.util.concurrent.ExecutorService;
+
+import org.opensearch.common.settings.Settings;
+import org.opensearch.threadpool.ExecutorBuilder;
+import org.opensearch.threadpool.FixedExecutorBuilder;
+import org.opensearch.threadpool.ThreadPool;
+
+/**
+ * Provides static methods related to executors for threat intel
+ */
+public class ThreatIntelExecutor {
+    private static final String THREAD_POOL_NAME = "plugin_sap_datasource_update"; //TODO: name
+    private final ThreadPool threadPool;
+
+    public ThreatIntelExecutor(final ThreadPool threadPool) {
+        this.threadPool = threadPool;
+    }
+
+    /**
+     * We use a fixed thread count of 1 for updating datasources, since updates run in the
+     * background at most once a day and do not need to be expedited.
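+     * <p>Registration sketch, mirroring what SecurityAnalyticsPlugin#getExecutorBuilders does
+     * elsewhere in this change (shown for context only):
+     * <pre>{@code
+     * public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
+     *     return List.of(ThreatIntelExecutor.executorBuilder(settings));
+     * }
+     * }</pre>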
+     *
+     * @param settings the settings
+     * @return the executor builder
+     */
+    public static ExecutorBuilder<?> executorBuilder(final Settings settings) {
+        return new FixedExecutorBuilder(settings, THREAD_POOL_NAME, 1, 1000, THREAD_POOL_NAME, false);
+    }
+
+    /**
+     * Return an executor service for the datasource update task
+     *
+     * @return the executor service
+     */
+    public ExecutorService forDatasourceUpdate() {
+        return threadPool.executor(THREAD_POOL_NAME);
+    }
+}
diff --git a/src/main/java/org/opensearch/securityanalytics/threatintel/common/ThreatIntelLockService.java b/src/main/java/org/opensearch/securityanalytics/threatintel/common/ThreatIntelLockService.java
new file mode 100644
index 000000000..e3da25879
--- /dev/null
+++ b/src/main/java/org/opensearch/securityanalytics/threatintel/common/ThreatIntelLockService.java
@@ -0,0 +1,168 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.securityanalytics.threatIntel.common;
+
+import static org.opensearch.securityanalytics.threatIntel.jobscheduler.DatasourceExtension.JOB_INDEX_NAME;
+
+import java.time.Instant;
+import java.util.Optional;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+import org.opensearch.OpenSearchException;
+import org.opensearch.client.Client;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.jobscheduler.spi.LockModel;
+import org.opensearch.jobscheduler.spi.utils.LockService;
+import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings;
+
+/**
+ * A wrapper of the job scheduler's lock service for datasources
+ */
+public class ThreatIntelLockService {
+    private static final Logger log = LogManager.getLogger(ThreatIntelLockService.class);
+
+    public static final long LOCK_DURATION_IN_SECONDS = 300L;
+    public static final long RENEW_AFTER_IN_SECONDS = 120L;
+
+    private final ClusterService clusterService;
+    private final LockService lockService;
+
+
+    /**
+     * Constructor
+     *
+     * @param clusterService the cluster service
+     * @param client the client
+     */
+    public ThreatIntelLockService(final ClusterService clusterService, final Client client) {
+        this.clusterService = clusterService;
+        this.lockService = new LockService(client, clusterService);
+    }
+
+    /**
+     * Wrapper method of LockService#acquireLockWithId
+     *
+     * Datasource uses its name as the doc id in the job scheduler. Therefore, we can use the datasource name to acquire
+     * a lock on a datasource.
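+     * <p>Illustrative asynchronous usage (the listener body is a hypothetical sketch):
+     * <pre>{@code
+     * lockService.acquireLock(datasourceName, ThreatIntelLockService.LOCK_DURATION_IN_SECONDS,
+     *     ActionListener.wrap(lock -> {
+     *         try {
+     *             // ... do work that requires the datasource lock ...
+     *         } finally {
+     *             lockService.releaseLock(lock);
+     *         }
+     *     }, e -> log.error("failed to acquire a lock", e)));
+     * }</pre>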
+     *
+     * @param datasourceName the name of the datasource to acquire a lock on
+     * @param lockDurationSeconds the lock duration in seconds
+     * @param listener the listener
+     */
+    public void acquireLock(final String datasourceName, final Long lockDurationSeconds, final ActionListener<LockModel> listener) {
+        lockService.acquireLockWithId(JOB_INDEX_NAME, lockDurationSeconds, datasourceName, listener);
+    }
+
+    /**
+     * Synchronous method of #acquireLock
+     *
+     * @param datasourceName the name of the datasource to acquire a lock on
+     * @param lockDurationSeconds the lock duration in seconds
+     * @return lock model
+     */
+    public Optional<LockModel> acquireLock(final String datasourceName, final Long lockDurationSeconds) {
+        AtomicReference<LockModel> lockReference = new AtomicReference<>();
+        CountDownLatch countDownLatch = new CountDownLatch(1);
+        lockService.acquireLockWithId(JOB_INDEX_NAME, lockDurationSeconds, datasourceName, new ActionListener<>() {
+            @Override
+            public void onResponse(final LockModel lockModel) {
+                lockReference.set(lockModel);
+                countDownLatch.countDown();
+            }
+
+            @Override
+            public void onFailure(final Exception e) {
+                lockReference.set(null);
+                countDownLatch.countDown();
+                log.error("acquiring lock failed", e);
+            }
+        });
+
+        try {
+            countDownLatch.await(clusterService.getClusterSettings().get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT).getSeconds(), TimeUnit.SECONDS);
+            return Optional.ofNullable(lockReference.get());
+        } catch (InterruptedException e) {
+            log.error("Waiting for the count down latch failed", e);
+            return Optional.empty();
+        }
+    }
+
+    /**
+     * Wrapper method of LockService#release
+     *
+     * @param lockModel the lock model
+     */
+    public void releaseLock(final LockModel lockModel) {
+        lockService.release(
+                lockModel,
+                ActionListener.wrap(released -> {}, exception -> log.error("Failed to release the lock", exception))
+        );
+    }
+
+    /**
+     * Synchronous method of LockService#renewLock
+     *
+     * @param lockModel lock to renew
+     * @return the renewed lock if renewal succeeded, and null otherwise
+     */
+    public LockModel renewLock(final LockModel lockModel) {
+        AtomicReference<LockModel> lockReference = new AtomicReference<>();
+        CountDownLatch countDownLatch = new CountDownLatch(1);
+        lockService.renewLock(lockModel, new ActionListener<>() {
+            @Override
+            public void onResponse(final LockModel lockModel) {
+                lockReference.set(lockModel);
+                countDownLatch.countDown();
+            }
+
+            @Override
+            public void onFailure(final Exception e) {
+                log.error("failed to renew lock", e);
+                lockReference.set(null);
+                countDownLatch.countDown();
+            }
+        });
+
+        try {
+            countDownLatch.await(clusterService.getClusterSettings().get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT).getSeconds(), TimeUnit.SECONDS);
+            return lockReference.get();
+        } catch (InterruptedException e) {
+            log.error("Interrupted exception", e);
+            return null;
+        }
+    }
+
+    /**
+     * Return a runnable which can renew the given lock model
+     *
+     * The runnable renews the lock and stores the renewed lock in the AtomicReference.
+     * To avoid resource abuse, it only renews the lock once {@code RENEW_AFTER_IN_SECONDS} have
+     * passed since the last time the lock was renewed.
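+     * <p>Sketch of the expected usage from a long-running update task (the caller shown here is
+     * hypothetical):
+     * <pre>{@code
+     * AtomicReference<LockModel> lockRef = new AtomicReference<>(acquiredLock);
+     * Runnable renewLock = lockService.getRenewLockRunnable(lockRef);
+     * // invoke renewLock.run() periodically while the update is in progress
+     * }</pre>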
+
+    /**
+     * Return a runnable which can renew the given lock model
+     *
+     * The runnable renews the lock and stores the renewed lock in the AtomicReference.
+     * It only renews the lock once {@code RENEW_AFTER_IN_SECONDS} have passed since
+     * the last time the lock was renewed, to avoid resource abuse.
+     *
+     * @param lockModel lock model to renew
+     * @return runnable which can renew the given lock for every call
+     */
+    public Runnable getRenewLockRunnable(final AtomicReference<LockModel> lockModel) {
+        return () -> {
+            LockModel preLock = lockModel.get();
+            if (Instant.now().isBefore(preLock.getLockTime().plusSeconds(RENEW_AFTER_IN_SECONDS))) {
+                return;
+            }
+            lockModel.set(renewLock(lockModel.get()));
+            if (lockModel.get() == null) {
+                log.error("Exception: failed to renew a lock");
+                throw new OpenSearchException("failed to renew a lock [{}]", preLock);
+            }
+        };
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/org/opensearch/securityanalytics/threatintel/dao/DatasourceDao.java b/src/main/java/org/opensearch/securityanalytics/threatintel/dao/DatasourceDao.java
new file mode 100644
index 000000000..c6d8db18f
--- /dev/null
+++ b/src/main/java/org/opensearch/securityanalytics/threatintel/dao/DatasourceDao.java
@@ -0,0 +1,379 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.securityanalytics.threatIntel.dao;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.OpenSearchException;
+import org.opensearch.ResourceAlreadyExistsException;
+import org.opensearch.ResourceNotFoundException;
+import org.opensearch.action.DocWriteRequest;
+import org.opensearch.action.StepListener;
+import org.opensearch.action.admin.indices.create.CreateIndexRequest;
+import org.opensearch.action.admin.indices.create.CreateIndexResponse;
+import org.opensearch.action.bulk.BulkRequest;
+import org.opensearch.action.bulk.BulkResponse;
+import org.opensearch.action.delete.DeleteResponse;
+import org.opensearch.action.get.GetRequest;
+import org.opensearch.action.get.GetResponse;
+import org.opensearch.action.get.MultiGetItemResponse;
+import org.opensearch.action.get.MultiGetResponse;
+import org.opensearch.action.index.IndexRequest;
+import org.opensearch.action.index.IndexResponse;
+import org.opensearch.action.search.SearchResponse;
+import org.opensearch.action.support.WriteRequest;
+import org.opensearch.client.Client;
+import org.opensearch.cluster.routing.Preference;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.ClusterSettings;
+import org.opensearch.common.xcontent.LoggingDeprecationHandler;
+import org.opensearch.common.xcontent.XContentFactory;
+import org.opensearch.common.xcontent.XContentHelper;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.common.bytes.BytesReference;
+import org.opensearch.core.rest.RestStatus;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.ToXContent;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.securityanalytics.model.DetectorTrigger;
+import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings;
+import org.opensearch.securityanalytics.threatIntel.jobscheduler.Datasource;
+import org.opensearch.securityanalytics.threatIntel.jobscheduler.DatasourceExtension;
+import org.opensearch.securityanalytics.threatIntel.common.StashedThreadContext;
+import 
org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.securityanalytics.util.SecurityAnalyticsException; + +/** + * Data access object for datasource + */ +public class DatasourceDao { + private static final Logger log = LogManager.getLogger(DetectorTrigger.class); + + private static final Integer MAX_SIZE = 1000; + private final Client client; + private final ClusterService clusterService; + private final ClusterSettings clusterSettings; + + public DatasourceDao(final Client client, final ClusterService clusterService) { + this.client = client; + this.clusterService = clusterService; + this.clusterSettings = clusterService.getClusterSettings(); + } + + /** + * Create datasource index + * + * @param stepListener setup listener + */ + public void createIndexIfNotExists(final StepListener stepListener) { + if (clusterService.state().metadata().hasIndex(DatasourceExtension.JOB_INDEX_NAME) == true) { + stepListener.onResponse(null); + return; + } + final CreateIndexRequest createIndexRequest = new CreateIndexRequest(DatasourceExtension.JOB_INDEX_NAME).mapping(getIndexMapping()) + .settings(DatasourceExtension.INDEX_SETTING); + StashedThreadContext.run(client, () -> client.admin().indices().create(createIndexRequest, new ActionListener<>() { + @Override + public void onResponse(final CreateIndexResponse createIndexResponse) { + stepListener.onResponse(null); + } + + @Override + public void onFailure(final Exception e) { + if (e instanceof ResourceAlreadyExistsException) { + log.info("index[{}] already exist", DatasourceExtension.JOB_INDEX_NAME); + stepListener.onResponse(null); + return; + } + stepListener.onFailure(e); + } + })); + } //TODO: change this to create a Datasource + private String getIndexMapping() { + try { + try (InputStream is = DatasourceDao.class.getResourceAsStream("/mappings/threat_intel_datasource.json")) { + try (BufferedReader reader = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) { + return reader.lines().map(String::trim).collect(Collectors.joining()); + } + } + } catch (IOException e) { + log.error("Runtime exception", e); + throw new SecurityAnalyticsException("Runtime exception", RestStatus.INTERNAL_SERVER_ERROR, e); //TODO + } + } + + /** + * Update datasource in an index {@code DatasourceExtension.JOB_INDEX_NAME} + * @param datasource the datasource + * @return index response + */ + public IndexResponse updateDatasource(final Datasource datasource) { + datasource.setLastUpdateTime(Instant.now()); + return StashedThreadContext.run(client, () -> { + try { + return client.prepareIndex(DatasourceExtension.JOB_INDEX_NAME) + .setId(datasource.getName()) + .setOpType(DocWriteRequest.OpType.INDEX) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .setSource(datasource.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .execute() + .actionGet(clusterSettings.get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT)); + } catch (IOException e) { + throw new SecurityAnalyticsException("Runtime exception", RestStatus.INTERNAL_SERVER_ERROR, e); //TODO + } + }); + } + + private IndexRequest toIndexRequest(Datasource datasource) { + try { + IndexRequest indexRequest = new IndexRequest(); + indexRequest.index(DatasourceExtension.JOB_INDEX_NAME); + indexRequest.id(datasource.getName()); + indexRequest.opType(DocWriteRequest.OpType.INDEX); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + 
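+            // serialize the datasource as the document source of the index request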
indexRequest.source(datasource.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); + return indexRequest; + } catch (IOException e) { + throw new SecurityAnalyticsException("Runtime exception", RestStatus.INTERNAL_SERVER_ERROR, e); //TODO + } + } + + /** + * Get datasource from an index {@code DatasourceExtension.JOB_INDEX_NAME} + * @param name the name of a datasource + * @return datasource + * @throws IOException exception + */ + public Datasource getDatasource(final String name) throws IOException { + GetRequest request = new GetRequest(DatasourceExtension.JOB_INDEX_NAME, name); + GetResponse response; + try { + response = StashedThreadContext.run(client, () -> client.get(request).actionGet(clusterSettings.get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT))); + if (response.isExists() == false) { + log.error("Datasource[{}] does not exist in an index[{}]", name, DatasourceExtension.JOB_INDEX_NAME); + return null; + } + } catch (IndexNotFoundException e) { + log.error("Index[{}] is not found", DatasourceExtension.JOB_INDEX_NAME); + return null; + } + + XContentParser parser = XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + response.getSourceAsBytesRef() + ); + return Datasource.PARSER.parse(parser, null); + } + + /** + * Update datasources in an index {@code DatasourceExtension.JOB_INDEX_NAME} + * @param datasources the datasources + * @param listener action listener + */ + public void updateDatasource(final List datasources, final ActionListener listener) { + BulkRequest bulkRequest = new BulkRequest(); + datasources.stream().map(datasource -> { + datasource.setLastUpdateTime(Instant.now()); + return datasource; + }).map(this::toIndexRequest).forEach(indexRequest -> bulkRequest.add(indexRequest)); + StashedThreadContext.run(client, () -> client.bulk(bulkRequest, listener)); + } + + /** + * Put datasource in an index {@code DatasourceExtension.JOB_INDEX_NAME} + * + * @param datasource the datasource + * @param listener the listener + */ + public void putDatasource(final Datasource datasource, final ActionListener listener) { + datasource.setLastUpdateTime(Instant.now()); + StashedThreadContext.run(client, () -> { + try { + client.prepareIndex(DatasourceExtension.JOB_INDEX_NAME) + .setId(datasource.getName()) + .setOpType(DocWriteRequest.OpType.CREATE) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .setSource(datasource.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .execute(listener); + } catch (IOException e) { + throw new SecurityAnalyticsException("Runtime exception", RestStatus.INTERNAL_SERVER_ERROR, e); //TODO + } + }); + } // need to use this somewhere + + /** + * Delete datasource in an index {@code DatasourceExtension.JOB_INDEX_NAME} + * + * @param datasource the datasource + * + */ + public void deleteDatasource(final Datasource datasource) { + DeleteResponse response = client.prepareDelete() + .setIndex(DatasourceExtension.JOB_INDEX_NAME) + .setId(datasource.getName()) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .execute() + .actionGet(clusterSettings.get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT)); + + if (response.status().equals(RestStatus.OK)) { + log.info("deleted datasource[{}] successfully", datasource.getName()); + } else if (response.status().equals(RestStatus.NOT_FOUND)) { + throw new ResourceNotFoundException("datasource[{}] does not exist", datasource.getName()); + } else { + throw new OpenSearchException("failed to delete datasource[{}] 
with status[{}]", datasource.getName(), response.status()); + } + } + + /** + * Get datasource from an index {@code DatasourceExtension.JOB_INDEX_NAME} + * @param name the name of a datasource + * @param actionListener the action listener + */ + public void getDatasource(final String name, final ActionListener actionListener) { + GetRequest request = new GetRequest(DatasourceExtension.JOB_INDEX_NAME, name); + StashedThreadContext.run(client, () -> client.get(request, new ActionListener<>() { + @Override + public void onResponse(final GetResponse response) { + if (response.isExists() == false) { + actionListener.onResponse(null); + return; + } + + try { + XContentParser parser = XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + response.getSourceAsBytesRef() + ); + actionListener.onResponse(Datasource.PARSER.parse(parser, null)); + } catch (IOException e) { + actionListener.onFailure(e); + } + } + + @Override + public void onFailure(final Exception e) { + actionListener.onFailure(e); + } + })); + } + + /** + * Get datasources from an index {@code DatasourceExtension.JOB_INDEX_NAME} + * @param names the array of datasource names + * @param actionListener the action listener + */ + public void getDatasources(final String[] names, final ActionListener> actionListener) { + StashedThreadContext.run( + client, + () -> client.prepareMultiGet() + .add(DatasourceExtension.JOB_INDEX_NAME, names) + .execute(createGetDataSourceQueryActionLister(MultiGetResponse.class, actionListener)) + ); + } + + /** + * Get all datasources up to {@code MAX_SIZE} from an index {@code DatasourceExtension.JOB_INDEX_NAME} + * @param actionListener the action listener + */ + public void getAllDatasources(final ActionListener> actionListener) { + StashedThreadContext.run( + client, + () -> client.prepareSearch(DatasourceExtension.JOB_INDEX_NAME) + .setQuery(QueryBuilders.matchAllQuery()) + .setPreference(Preference.PRIMARY.type()) + .setSize(MAX_SIZE) + .execute(createGetDataSourceQueryActionLister(SearchResponse.class, actionListener)) + ); + } + + /** + * Get all datasources up to {@code MAX_SIZE} from an index {@code DatasourceExtension.JOB_INDEX_NAME} + */ + public List getAllDatasources() { + SearchResponse response = StashedThreadContext.run( + client, + () -> client.prepareSearch(DatasourceExtension.JOB_INDEX_NAME) + .setQuery(QueryBuilders.matchAllQuery()) + .setPreference(Preference.PRIMARY.type()) + .setSize(MAX_SIZE) + .execute() + .actionGet(clusterSettings.get(SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT)) + ); + + List bytesReferences = toBytesReferences(response); + return bytesReferences.stream().map(bytesRef -> toDatasource(bytesRef)).collect(Collectors.toList()); + } + + private ActionListener createGetDataSourceQueryActionLister( + final Class response, + final ActionListener> actionListener + ) { + return new ActionListener() { + @Override + public void onResponse(final T response) { + try { + List bytesReferences = toBytesReferences(response); + List datasources = bytesReferences.stream() + .map(bytesRef -> toDatasource(bytesRef)) + .collect(Collectors.toList()); + actionListener.onResponse(datasources); + } catch (Exception e) { + actionListener.onFailure(e); + } + } + + @Override + public void onFailure(final Exception e) { + actionListener.onFailure(e); + } + }; + } + + private List toBytesReferences(final Object response) { + if (response instanceof SearchResponse) { + SearchResponse searchResponse = (SearchResponse) response; + return 
Arrays.stream(searchResponse.getHits().getHits()).map(SearchHit::getSourceRef).collect(Collectors.toList()); + } else if (response instanceof MultiGetResponse) { + MultiGetResponse multiGetResponse = (MultiGetResponse) response; + return Arrays.stream(multiGetResponse.getResponses()) + .map(MultiGetItemResponse::getResponse) + .filter(Objects::nonNull) + .filter(GetResponse::isExists) + .map(GetResponse::getSourceAsBytesRef) + .collect(Collectors.toList()); + } else { + throw new OpenSearchException("No supported instance type[{}] is provided", response.getClass()); + } + } + + private Datasource toDatasource(final BytesReference bytesReference) { + try { + XContentParser parser = XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + bytesReference + ); + return Datasource.PARSER.parse(parser, null); + } catch (IOException e) { + throw new SecurityAnalyticsException("Runtime exception", RestStatus.INTERNAL_SERVER_ERROR, e); //TODO + } + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/threatintel/jobscheduler/Datasource.java b/src/main/java/org/opensearch/securityanalytics/threatintel/jobscheduler/Datasource.java new file mode 100644 index 000000000..20fbd36bc --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatintel/jobscheduler/Datasource.java @@ -0,0 +1,861 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.securityanalytics.threatIntel.jobscheduler; + +import org.opensearch.core.ParseField; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ConstructingObjectParser; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.jobscheduler.spi.schedule.ScheduleParser; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.*; + +import static org.opensearch.common.time.DateUtils.toInstant; + +import org.opensearch.securityanalytics.threatIntel.common.DatasourceManifest; +import org.opensearch.securityanalytics.threatIntel.common.DatasourceState; +import org.opensearch.securityanalytics.threatIntel.common.ThreatIntelLockService; + +public class Datasource implements Writeable, ScheduledJobParameter { + /** + * Prefix of indices having threatIntel data + */ + public static final String THREAT_INTEL_DATA_INDEX_NAME_PREFIX = "opensearch-sap-threatintel"; + + /** + * Default fields for job scheduling + */ + private static final ParseField NAME_FIELD = new ParseField("name"); + private static final ParseField ENABLED_FIELD = new ParseField("update_enabled"); + private static final ParseField LAST_UPDATE_TIME_FIELD = new ParseField("last_update_time"); + private static final ParseField LAST_UPDATE_TIME_FIELD_READABLE = new ParseField("last_update_time_field"); + public static final ParseField SCHEDULE_FIELD = new ParseField("schedule"); + private static final ParseField ENABLED_TIME_FIELD = new ParseField("enabled_time"); + private static final ParseField ENABLED_TIME_FIELD_READABLE = 
new ParseField("enabled_time_field"); + + /** + * Additional fields for datasource + */ + private static final ParseField STATE_FIELD = new ParseField("state"); + private static final ParseField CURRENT_INDEX_FIELD = new ParseField("current_index"); + private static final ParseField INDICES_FIELD = new ParseField("indices"); + private static final ParseField DATABASE_FIELD = new ParseField("database"); + private static final ParseField UPDATE_STATS_FIELD = new ParseField("update_stats"); + private static final ParseField TASK_FIELD = new ParseField("task"); + + + /** + * Default variables for job scheduling + */ + + /** + * @param name name of a datasource + * @return name of a datasource + */ + private String name; + + /** + * @param lastUpdateTime Last update time of a datasource + * @return Last update time of a datasource + */ + private Instant lastUpdateTime; + /** + * @param enabledTime Last time when a scheduling is enabled for a threat intel feed data update + * @return Last time when a scheduling is enabled for the job scheduler + */ + private Instant enabledTime; + /** + * @param isEnabled Indicate if threat intel feed data update is scheduled or not + * @return Indicate if scheduling is enabled or not + */ + private boolean isEnabled; + /** + * @param schedule Schedule that system uses + * @return Schedule that system uses + */ + private IntervalSchedule schedule; + + + /** + * Additional variables for datasource + */ + + /** + * @param state State of a datasource + * @return State of a datasource + */ + private DatasourceState state; + + /** + * @param currentIndex the current index name having threat intel feed data + * @return the current index name having threat intel feed data + */ + private String currentIndex; + + /** + * @param indices A list of indices having threat intel feed data including currentIndex + * @return A list of indices having threat intel feed data including currentIndex + */ + private List indices; + + /** + * @param database threat intel feed database information + * @return threat intel feed database information + */ + private Database database; + + /** + * @param updateStats threat intel feed database update statistics + * @return threat intel feed database update statistics + */ + private UpdateStats updateStats; + + /** + * @param task Task that {@link DatasourceRunner} will execute + * @return Task that {@link DatasourceRunner} will execute + */ + private DatasourceTask task; + + /** + * Datasource parser + */ + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "datasource_metadata", + true, + args -> { + String name = (String) args[0]; + Instant lastUpdateTime = Instant.ofEpochMilli((long) args[1]); + Instant enabledTime = args[2] == null ? 
null : Instant.ofEpochMilli((long) args[2]); + boolean isEnabled = (boolean) args[3]; + IntervalSchedule schedule = (IntervalSchedule) args[4]; + DatasourceTask task = DatasourceTask.valueOf((String) args[5]); + DatasourceState state = DatasourceState.valueOf((String) args[6]); + String currentIndex = (String) args[7]; + List indices = (List) args[8]; + Database database = (Database) args[9]; + UpdateStats updateStats = (UpdateStats) args[10]; + Datasource parameter = new Datasource( + name, + lastUpdateTime, + enabledTime, + isEnabled, + schedule, + task, + state, + currentIndex, + indices, + database, + updateStats + ); + return parameter; + } + ); + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), NAME_FIELD); + PARSER.declareLong(ConstructingObjectParser.constructorArg(), LAST_UPDATE_TIME_FIELD); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), ENABLED_TIME_FIELD); + PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), ENABLED_FIELD); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> ScheduleParser.parse(p), SCHEDULE_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), TASK_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), STATE_FIELD); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), CURRENT_INDEX_FIELD); + PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), INDICES_FIELD); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), Database.PARSER, DATABASE_FIELD); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), UpdateStats.PARSER, UPDATE_STATS_FIELD); + } + + public Datasource() { + this(null, null); + } + + public Datasource(final String name, final Instant lastUpdateTime, final Instant enabledTime, final Boolean isEnabled, + final IntervalSchedule schedule, DatasourceTask task, final DatasourceState state, final String currentIndex, + final List indices, final Database database, final UpdateStats updateStats) { + this.name = name; + this.lastUpdateTime = lastUpdateTime; + this.enabledTime = enabledTime; + this.isEnabled = isEnabled; + this.schedule = schedule; + this.task = task; + this.state = state; + this.currentIndex = currentIndex; + this.indices = indices; + this.database = database; + this.updateStats = updateStats; + } + + public Datasource(final String name, final IntervalSchedule schedule) { + this( + name, + Instant.now().truncatedTo(ChronoUnit.MILLIS), + null, + false, + schedule, + DatasourceTask.ALL, + DatasourceState.CREATING, + null, + new ArrayList<>(), + new Database(), + new UpdateStats() + ); + } + + public Datasource(final StreamInput in) throws IOException { + name = in.readString(); + lastUpdateTime = toInstant(in.readVLong()); + enabledTime = toInstant(in.readOptionalVLong()); + isEnabled = in.readBoolean(); + schedule = new IntervalSchedule(in); + task = DatasourceTask.valueOf(in.readString()); + state = DatasourceState.valueOf(in.readString()); + currentIndex = in.readOptionalString(); + indices = in.readStringList(); + database = new Database(in); + updateStats = new UpdateStats(in); + } + + public void writeTo(final StreamOutput out) throws IOException { + out.writeString(name); + out.writeVLong(lastUpdateTime.toEpochMilli()); + out.writeOptionalVLong(enabledTime == null ? 
null : enabledTime.toEpochMilli()); + out.writeBoolean(isEnabled); + schedule.writeTo(out); + out.writeString(task.name()); + out.writeString(state.name()); + out.writeOptionalString(currentIndex); + out.writeStringCollection(indices); + database.writeTo(out); + updateStats.writeTo(out); + } + + @Override + public XContentBuilder toXContent(final XContentBuilder builder, final Params params) throws IOException { + builder.startObject(); + builder.field(NAME_FIELD.getPreferredName(), name); + builder.timeField( + LAST_UPDATE_TIME_FIELD.getPreferredName(), + LAST_UPDATE_TIME_FIELD_READABLE.getPreferredName(), + lastUpdateTime.toEpochMilli() + ); + if (enabledTime != null) { + builder.timeField( + ENABLED_TIME_FIELD.getPreferredName(), + ENABLED_TIME_FIELD_READABLE.getPreferredName(), + enabledTime.toEpochMilli() + ); + } + builder.field(ENABLED_FIELD.getPreferredName(), isEnabled); + builder.field(SCHEDULE_FIELD.getPreferredName(), schedule); + builder.field(TASK_FIELD.getPreferredName(), task.name()); + builder.field(STATE_FIELD.getPreferredName(), state.name()); + if (currentIndex != null) { + builder.field(CURRENT_INDEX_FIELD.getPreferredName(), currentIndex); + } + builder.field(INDICES_FIELD.getPreferredName(), indices); + builder.field(DATABASE_FIELD.getPreferredName(), database); + builder.field(UPDATE_STATS_FIELD.getPreferredName(), updateStats); + builder.endObject(); + return builder; + } + + // getters and setters + public void setName(String name) { + this.name = name; + } + public void setEnabledTime(Instant enabledTime) { + this.enabledTime = enabledTime; + } + + public void setEnabled(boolean enabled) { + isEnabled = enabled; + } + + public void setIndices(List indices) { + this.indices = indices; + } + + public void setDatabase(Database database) { + this.database = database; + } + public void setUpdateStats(UpdateStats updateStats) { + this.updateStats = updateStats; + } + + @Override + public String getName() { + return this.name; + } + @Override + public Instant getLastUpdateTime() { + return this.lastUpdateTime; + } + @Override + public Instant getEnabledTime() { + return this.enabledTime; + } + @Override + public IntervalSchedule getSchedule() { + return this.schedule; + } + @Override + public boolean isEnabled() { + return this.isEnabled; + } + + public DatasourceTask getTask() { + return task; + } + public void setLastUpdateTime(Instant lastUpdateTime) { + this.lastUpdateTime = lastUpdateTime; + } + public void setCurrentIndex(String currentIndex) { + this.currentIndex = currentIndex; + } + + public void setTask(DatasourceTask task) { + this.task = task; + } + @Override + public Long getLockDurationSeconds() { + return ThreatIntelLockService.LOCK_DURATION_IN_SECONDS; + } + + /** + * Enable auto update of threat intel feed data + */ + public void enable() { + if (isEnabled == true) { + return; + } + enabledTime = Instant.now().truncatedTo(ChronoUnit.MILLIS); + isEnabled = true; + } + + /** + * Disable auto update of threat intel feed data + */ + public void disable() { + enabledTime = null; + isEnabled = false; + } + + /** + * Current index name of a datasource + * + * @return Current index name of a datasource + */ + public String currentIndexName() { + return currentIndex; + } + + public void setSchedule(IntervalSchedule schedule) { + this.schedule = schedule; + } + + /** + * Reset database so that it can be updated in next run regardless there is new update or not + */ + public void resetDatabase() { + database.setFeedId(null); + database.setFeedName(null); + 
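+        // null out every feed attribute so the next scheduled run re-populates it from the manifest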
database.setFeedFormat(null);
+        database.setEndpoint(null);
+        database.setDescription(null);
+        database.setOrganization(null);
+        database.setContained_iocs_field(null);
+        database.setIocCol(null);
+    }
+
+    /**
+     * Index name for a datasource with given suffix
+     *
+     * @param suffix the suffix of an index name
+     * @return index name for a datasource with given suffix
+     */
+    public String newIndexName(final String suffix) {
+        return String.format(Locale.ROOT, "%s.%s.%s", THREAT_INTEL_DATA_INDEX_NAME_PREFIX, name, suffix);
+    }
+
+    /**
+     * Set database attributes with given input
+     *
+     * @param datasourceManifest the datasource manifest
+     * @param fields the fields
+     */
+    public void setDatabase(final DatasourceManifest datasourceManifest, final List<String> fields) {
+        this.database.feedId = datasourceManifest.getFeedId();
+        this.database.feedName = datasourceManifest.getName();
+        this.database.feedFormat = datasourceManifest.getFeedType();
+        this.database.endpoint = datasourceManifest.getUrl();
+        this.database.organization = datasourceManifest.getOrganization();
+        this.database.description = datasourceManifest.getDescription();
+        this.database.contained_iocs_field = datasourceManifest.getContainedIocs();
+        this.database.iocCol = datasourceManifest.getIocCol();
+        this.database.fields = fields;
+    }
+
+    /**
+     * Checks if the database fields are compatible with the given set of fields.
+     *
+     * If database fields are null, it is compatible with any input fields
+     * as it hasn't been generated before.
+     *
+     * @param fields The set of input fields to check for compatibility.
+     * @return true if the database fields are compatible with the given input fields, false otherwise.
+     */
+    public boolean isCompatible(final List<String> fields) {
+        if (database.fields == null) {
+            return true;
+        }
+
+        if (fields.size() < database.fields.size()) {
+            return false;
+        }
+
+        Set<String> fieldsSet = new HashSet<>(fields);
+        for (String field : database.fields) {
+            if (fieldsSet.contains(field) == false) {
+                return false;
+            }
+        }
+        return true;
+    }
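+
+    /*
+     * Compatibility sketch for isCompatible() above, with hypothetical field lists:
+     * existing fields [ip, type] and new fields [ip, type, severity] are compatible,
+     * because every old field is still present; new fields [ip, severity] are not,
+     * because "type" disappeared and already-stored data would lose that field.
+     */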
+
+    public DatasourceState getState() {
+        return state;
+    }
+
+    public List<String> getIndices() {
+        return indices;
+    }
+
+    public void setState(DatasourceState previousState) {
+        this.state = previousState;
+    }
+
+    public Database getDatabase() {
+        return this.database;
+    }
+
+    public UpdateStats getUpdateStats() {
+        return this.updateStats;
+    }
+
+    /**
+     * Database of a datasource
+     */
+    public static class Database implements Writeable, ToXContent { // feed metadata
+        private static final ParseField FEED_ID = new ParseField("feed_id");
+        private static final ParseField FEED_NAME = new ParseField("feed_name");
+        private static final ParseField FEED_FORMAT = new ParseField("feed_format");
+        private static final ParseField ENDPOINT_FIELD = new ParseField("endpoint");
+        private static final ParseField DESCRIPTION = new ParseField("description");
+        private static final ParseField ORGANIZATION = new ParseField("organization");
+        private static final ParseField CONTAINED_IOCS_FIELD = new ParseField("contained_iocs_field");
+        private static final ParseField IOC_COL = new ParseField("ioc_col");
+        private static final ParseField FIELDS_FIELD = new ParseField("fields");
+
+        /**
+         * @param feedId id of the feed
+         * @return id of the feed
+         */
+        private String feedId;
+
+        /**
+         * @param feedFormat format of the feed (csv, json...)
+         * @return the type of feed ingested
+         */
+        private String feedFormat;
+
+        /**
+         * @param endpoint URL of a manifest file
+         * @return URL of a manifest file
+         */
+        private String endpoint;
+
+        /**
+         * @param feedName name of the threat intel feed
+         * @return name of the threat intel feed
+         */
+        private String feedName;
+
+        /**
+         * @param description description of the threat intel feed
+         * @return description of the threat intel feed
+         */
+        private String description;
+
+        /**
+         * @param organization organization of the threat intel feed
+         * @return organization of the threat intel feed
+         */
+        private String organization;
+
+        /**
+         * @param contained_iocs_field list of iocs contained in a given feed
+         * @return list of iocs contained in a given feed
+         */
+        private List<String> contained_iocs_field;
+
+        /**
+         * @param iocCol column of the contained ioc
+         * @return column of the contained ioc
+         */
+        private String iocCol;
+
+        /**
+         * @param fields A list of available fields in the database
+         * @return A list of available fields in the database
+         */
+        private List<String> fields;
+
+        public Database(String feedId, String feedName, String feedFormat, final String endpoint, final String description,
+                        final String organization, final List<String> contained_iocs_field, final String iocCol, final List<String> fields) {
+            this.feedId = feedId;
+            this.feedName = feedName;
+            this.feedFormat = feedFormat;
+            this.endpoint = endpoint;
+            this.description = description;
+            this.organization = organization;
+            this.contained_iocs_field = contained_iocs_field;
+            this.iocCol = iocCol;
+            this.fields = fields;
+        }
+
+        private static final ConstructingObjectParser<Database, Void> PARSER = new ConstructingObjectParser<>(
+            "datasource_metadata_database",
+            true,
+            args -> {
+                String feedId = (String) args[0];
+                String feedName = (String) args[1];
+                String feedFormat = (String) args[2];
+                String endpoint = (String) args[3];
+                String description = (String) args[4];
+                String organization = (String) args[5];
+                List<String> contained_iocs_field = (List<String>) args[6];
+                String iocCol = (String) args[7];
+                List<String> fields = (List<String>) args[8];
+                return new Database(feedId, feedName, feedFormat, endpoint, description, organization, contained_iocs_field, iocCol, fields);
+            }
+        );
+        static {
+            PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), FEED_ID);
+            PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), FEED_NAME);
+            PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), FEED_FORMAT);
+            PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), ENDPOINT_FIELD);
+            PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), DESCRIPTION);
+            PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), ORGANIZATION);
+            PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), CONTAINED_IOCS_FIELD);
+            PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), IOC_COL);
+            PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), FIELDS_FIELD);
+        }
+
+        public Database(final StreamInput in) throws IOException {
+            feedId = in.readString();
+            feedName = in.readString();
+            feedFormat = in.readString();
+            endpoint = in.readString();
+            description = in.readString();
+            organization = in.readString();
+            contained_iocs_field = in.readStringList();
+            iocCol = in.readString();
+            fields = in.readOptionalStringList();
+        }
+
+        private Database() {}
+
+        @Override
+        public void writeTo(final StreamOutput out) throws IOException {
+            out.writeString(feedId);
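+            // note: writeString() rejects nulls, so these fields are only safe to serialize
+            // while populated; after resetDatabase() clears them, optional writes would be needed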
out.writeString(feedName); + out.writeString(feedFormat); + out.writeString(endpoint); + out.writeString(description); + out.writeString(organization); + out.writeStringCollection(contained_iocs_field); + out.writeString(iocCol); + out.writeOptionalStringCollection(fields); + } + + @Override + public XContentBuilder toXContent(final XContentBuilder builder, final Params params) throws IOException { + builder.startObject(); + builder.field(FEED_ID.getPreferredName(), feedId); + builder.field(FEED_NAME.getPreferredName(), feedName); + builder.field(FEED_FORMAT.getPreferredName(), feedFormat); + builder.field(ENDPOINT_FIELD.getPreferredName(), endpoint); + builder.field(DESCRIPTION.getPreferredName(), description); + builder.field(ORGANIZATION.getPreferredName(), organization); + builder.field(CONTAINED_IOCS_FIELD.getPreferredName(), contained_iocs_field); + builder.field(IOC_COL.getPreferredName(), iocCol); + +// if (provider != null) { +// builder.field(PROVIDER_FIELD.getPreferredName(), provider); +// } +// if (updatedAt != null) { +// builder.timeField( +// UPDATED_AT_FIELD.getPreferredName(), +// UPDATED_AT_FIELD_READABLE.getPreferredName(), +// updatedAt.toEpochMilli() +// ); +// } + if (fields != null) { + builder.startArray(FIELDS_FIELD.getPreferredName()); + for (String field : fields) { + builder.value(field); + } + builder.endArray(); + } + builder.endObject(); + return builder; + } + + public String getFeedId() { + return feedId; + } + + public String getFeedFormat() { + return feedFormat; + } + + public String getFeedName() { + return feedName; + } + + public String getDescription() { + return description; + } + + public String getOrganization() { + return organization; + } + + public List getContained_iocs_field() { + return contained_iocs_field; + } + + public String getIocCol() { + return iocCol; + } + + public String getEndpoint() { + return this.endpoint; + } + + public List getFields() { + return fields; + } + public void setFeedId(String feedId) { + this.feedId = feedId; + } + + public void setFeedFormat(String feedFormat) { + this.feedFormat = feedFormat; + } + + public void setEndpoint(String endpoint) { + this.endpoint = endpoint; + } + + public void setFeedName(String feedName) { + this.feedName = feedName; + } + + public void setDescription(String description) { + this.description = description; + } + + public void setOrganization(String organization) { + this.organization = organization; + } + + public void setContained_iocs_field(List contained_iocs_field) { + this.contained_iocs_field = contained_iocs_field; + } + + public void setIocCol(String iocCol) { + this.iocCol = iocCol; + } + + public void setFields(List fields) { + this.fields = fields; + } + + } + + /** + * Update stats of a datasource + */ + public static class UpdateStats implements Writeable, ToXContent { + private static final ParseField LAST_SUCCEEDED_AT_FIELD = new ParseField("last_succeeded_at_in_epoch_millis"); + private static final ParseField LAST_SUCCEEDED_AT_FIELD_READABLE = new ParseField("last_succeeded_at"); + private static final ParseField LAST_PROCESSING_TIME_IN_MILLIS_FIELD = new ParseField("last_processing_time_in_millis"); + private static final ParseField LAST_FAILED_AT_FIELD = new ParseField("last_failed_at_in_epoch_millis"); + private static final ParseField LAST_FAILED_AT_FIELD_READABLE = new ParseField("last_failed_at"); + private static final ParseField LAST_SKIPPED_AT = new ParseField("last_skipped_at_in_epoch_millis"); + private static final ParseField LAST_SKIPPED_AT_READABLE 
= new ParseField("last_skipped_at"); + + public Instant getLastSucceededAt() { + return lastSucceededAt; + } + + public Long getLastProcessingTimeInMillis() { + return lastProcessingTimeInMillis; + } + + public Instant getLastFailedAt() { + return lastFailedAt; + } + + public Instant getLastSkippedAt() { + return lastSkippedAt; + } + + /** + * @param lastSucceededAt The last time when threat intel feed data update was succeeded + * @return The last time when threat intel feed data update was succeeded + */ + private Instant lastSucceededAt; + /** + * @param lastProcessingTimeInMillis The last processing time when threat intel feed data update was succeeded + * @return The last processing time when threat intel feed data update was succeeded + */ + private Long lastProcessingTimeInMillis; + /** + * @param lastFailedAt The last time when threat intel feed data update was failed + * @return The last time when threat intel feed data update was failed + */ + private Instant lastFailedAt; + + /** + * @param lastSkippedAt The last time when threat intel feed data update was skipped as there was no new update from an endpoint + * @return The last time when threat intel feed data update was skipped as there was no new update from an endpoint + */ + private Instant lastSkippedAt; + + private UpdateStats(){} + + public void setLastSkippedAt(Instant lastSkippedAt) { + this.lastSkippedAt = lastSkippedAt; + } + + public void setLastSucceededAt(Instant lastSucceededAt) { + this.lastSucceededAt = lastSucceededAt; + } + + public void setLastProcessingTimeInMillis(Long lastProcessingTimeInMillis) { + this.lastProcessingTimeInMillis = lastProcessingTimeInMillis; + } + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "datasource_metadata_update_stats", + true, + args -> { + Instant lastSucceededAt = args[0] == null ? null : Instant.ofEpochMilli((long) args[0]); + Long lastProcessingTimeInMillis = (Long) args[1]; + Instant lastFailedAt = args[2] == null ? null : Instant.ofEpochMilli((long) args[2]); + Instant lastSkippedAt = args[3] == null ? null : Instant.ofEpochMilli((long) args[3]); + return new UpdateStats(lastSucceededAt, lastProcessingTimeInMillis, lastFailedAt, lastSkippedAt); + } + ); + static { + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), LAST_SUCCEEDED_AT_FIELD); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), LAST_PROCESSING_TIME_IN_MILLIS_FIELD); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), LAST_FAILED_AT_FIELD); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), LAST_SKIPPED_AT); + } + + public UpdateStats(final StreamInput in) throws IOException { + lastSucceededAt = toInstant(in.readOptionalVLong()); + lastProcessingTimeInMillis = in.readOptionalVLong(); + lastFailedAt = toInstant(in.readOptionalVLong()); + lastSkippedAt = toInstant(in.readOptionalVLong()); + } + + public UpdateStats(Instant lastSucceededAt, Long lastProcessingTimeInMillis, Instant lastFailedAt, Instant lastSkippedAt) { + this.lastSucceededAt = lastSucceededAt; + this.lastProcessingTimeInMillis = lastProcessingTimeInMillis; + this.lastFailedAt = lastFailedAt; + this.lastSkippedAt = lastSkippedAt; + } + + @Override + public void writeTo(final StreamOutput out) throws IOException { + out.writeOptionalVLong(lastSucceededAt == null ? null : lastSucceededAt.toEpochMilli()); + out.writeOptionalVLong(lastProcessingTimeInMillis); + out.writeOptionalVLong(lastFailedAt == null ? 
null : lastFailedAt.toEpochMilli()); + out.writeOptionalVLong(lastSkippedAt == null ? null : lastSkippedAt.toEpochMilli()); + } + + @Override + public XContentBuilder toXContent(final XContentBuilder builder, final Params params) throws IOException { + builder.startObject(); + if (lastSucceededAt != null) { + builder.timeField( + LAST_SUCCEEDED_AT_FIELD.getPreferredName(), + LAST_SUCCEEDED_AT_FIELD_READABLE.getPreferredName(), + lastSucceededAt.toEpochMilli() + ); + } + if (lastProcessingTimeInMillis != null) { + builder.field(LAST_PROCESSING_TIME_IN_MILLIS_FIELD.getPreferredName(), lastProcessingTimeInMillis); + } + if (lastFailedAt != null) { + builder.timeField( + LAST_FAILED_AT_FIELD.getPreferredName(), + LAST_FAILED_AT_FIELD_READABLE.getPreferredName(), + lastFailedAt.toEpochMilli() + ); + } + if (lastSkippedAt != null) { + builder.timeField( + LAST_SKIPPED_AT.getPreferredName(), + LAST_SKIPPED_AT_READABLE.getPreferredName(), + lastSkippedAt.toEpochMilli() + ); + } + builder.endObject(); + return builder; + } + + public void setLastFailedAt(Instant now) { + this.lastFailedAt = now; + } + } + +// /** +// * Builder class for Datasource +// */ +// public static class Builder { +// public static Datasource build(final PutDatasourceRequest request) { +// String id = request.getName(); +// IntervalSchedule schedule = new IntervalSchedule( +// Instant.now().truncatedTo(ChronoUnit.MILLIS), +// (int) request.getUpdateInterval().days(), +// ChronoUnit.DAYS +// ); +// String feedFormat = request.getFeedFormat(); +// String endpoint = request.getEndpoint(); +// String feedName = request.getFeedName(); +// String description = request.getDescription(); +// String organization = request.getOrganization(); +// List contained_iocs_field = request.getContained_iocs_field(); +// return new Datasource(id, schedule, feedFormat, endpoint, feedName, description, organization, contained_iocs_field); +// } +// } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/threatintel/jobscheduler/DatasourceExtension.java b/src/main/java/org/opensearch/securityanalytics/threatintel/jobscheduler/DatasourceExtension.java new file mode 100644 index 000000000..4d32973e6 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatintel/jobscheduler/DatasourceExtension.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.jobscheduler; + +import org.opensearch.jobscheduler.spi.JobSchedulerExtension; +import org.opensearch.jobscheduler.spi.ScheduledJobParser; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; + +import java.util.Map; + +public class DatasourceExtension implements JobSchedulerExtension { + /** + * Job index name for a datasource + */ + public static final String JOB_INDEX_NAME = ".scheduler-security_analytics-threatintel-datasource"; //rename this... + + /** + * Job index setting + * + * We want it to be single shard so that job can be run only in a single node by job scheduler. + * We want it to expand to all replicas so that querying to this index can be done locally to reduce latency. 
+     */
+    public static final Map<String, Object> INDEX_SETTING = Map.of("index.number_of_shards", 1, "index.auto_expand_replicas", "0-all", "index.hidden", true);
+
+    @Override
+    public String getJobType() {
+        return "scheduler_security_analytics_threatintel_datasource";
+    }
+
+    @Override
+    public String getJobIndex() {
+        return JOB_INDEX_NAME;
+    }
+
+    @Override
+    public ScheduledJobRunner getJobRunner() {
+        return DatasourceRunner.getJobRunnerInstance();
+    }
+
+    @Override
+    public ScheduledJobParser getJobParser() {
+        return (parser, id, jobDocVersion) -> Datasource.PARSER.parse(parser, null);
+    }
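+
+    // Note: job scheduler discovers extensions through Java SPI; this class is
+    // typically registered in a resources file such as
+    // src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension
+    // (the registration file is not shown in this patch).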
+}
diff --git a/src/main/java/org/opensearch/securityanalytics/threatintel/jobscheduler/DatasourceRunner.java b/src/main/java/org/opensearch/securityanalytics/threatintel/jobscheduler/DatasourceRunner.java
new file mode 100644
index 000000000..ee36b355a
--- /dev/null
+++ b/src/main/java/org/opensearch/securityanalytics/threatintel/jobscheduler/DatasourceRunner.java
@@ -0,0 +1,160 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.securityanalytics.threatIntel.jobscheduler;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.jobscheduler.repackage.com.cronutils.utils.VisibleForTesting;
+import org.opensearch.jobscheduler.spi.JobExecutionContext;
+import org.opensearch.jobscheduler.spi.LockModel;
+import org.opensearch.jobscheduler.spi.ScheduledJobParameter;
+import org.opensearch.jobscheduler.spi.ScheduledJobRunner;
+import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule;
+
+import java.io.IOException;
+import java.time.temporal.ChronoUnit;
+import java.util.Optional;
+import java.util.concurrent.atomic.AtomicReference;
+import java.time.Instant;
+
+import org.opensearch.securityanalytics.threatIntel.common.DatasourceState;
+import org.opensearch.securityanalytics.threatIntel.common.ThreatIntelExecutor;
+import org.opensearch.securityanalytics.threatIntel.common.ThreatIntelLockService;
+import org.opensearch.securityanalytics.threatIntel.dao.DatasourceDao;
+
+/**
+ * Datasource update task
+ *
+ * This is a background task which is responsible for updating threat intel feed data
+ */
+public class DatasourceRunner implements ScheduledJobRunner {
+    private static final Logger log = LogManager.getLogger(DatasourceRunner.class);
+    private static DatasourceRunner INSTANCE;
+
+    public static DatasourceRunner getJobRunnerInstance() {
+        if (INSTANCE != null) {
+            return INSTANCE;
+        }
+        synchronized (DatasourceRunner.class) {
+            if (INSTANCE != null) {
+                return INSTANCE;
+            }
+            INSTANCE = new DatasourceRunner();
+            return INSTANCE;
+        }
+    }
+
+    private ClusterService clusterService;
+
+    // threat intel specific variables
+    private DatasourceUpdateService datasourceUpdateService;
+    private DatasourceDao datasourceDao;
+    private ThreatIntelExecutor threatIntelExecutor;
+    private ThreatIntelLockService lockService;
+    private boolean initialized;
+
+    private DatasourceRunner() {
+        // Singleton class, use getJobRunnerInstance method instead of constructor
+    }
+
+    public void initialize(
+        final ClusterService clusterService,
+        final DatasourceUpdateService datasourceUpdateService,
+        final DatasourceDao datasourceDao,
+        final ThreatIntelExecutor threatIntelExecutor,
+        final ThreatIntelLockService threatIntelLockService
+    ) {
+        this.clusterService = clusterService;
+        this.datasourceUpdateService = datasourceUpdateService;
+        this.datasourceDao = datasourceDao;
+        this.threatIntelExecutor = threatIntelExecutor;
+        this.lockService = threatIntelLockService;
+        this.initialized = true;
+    }
+
+    @Override
+    public void runJob(final ScheduledJobParameter jobParameter, final JobExecutionContext context) {
+        if (initialized == false) {
+            throw new AssertionError("this instance is not initialized");
+        }
+
+        log.info("Update job started for a datasource[{}]", jobParameter.getName());
+        if (jobParameter instanceof Datasource == false) {
+            log.error("Illegal state exception: job parameter is not an instance of Datasource");
+            throw new IllegalStateException(
+                "job parameter is not an instance of Datasource, type: " + jobParameter.getClass().getCanonicalName()
+            );
+        }
+        threatIntelExecutor.forDatasourceUpdate().submit(updateDatasourceRunner(jobParameter));
+    }
+
+    /**
+     * Update threat intel feed data
+     *
+     * A lock is used so that only one node runs this task at a time.
+     *
+     * @param jobParameter job parameter
+     */
+    protected Runnable updateDatasourceRunner(final ScheduledJobParameter jobParameter) {
+        return () -> {
+            Optional<LockModel> lockModel = lockService.acquireLock(
+                jobParameter.getName(),
+                ThreatIntelLockService.LOCK_DURATION_IN_SECONDS
+            );
+            if (lockModel.isEmpty()) {
+                log.error("Failed to update. Another processor is holding a lock for datasource[{}]", jobParameter.getName());
+                return;
+            }
+
+            LockModel lock = lockModel.get();
+            try {
+                updateDatasource(jobParameter, lockService.getRenewLockRunnable(new AtomicReference<>(lock)));
+            } catch (Exception e) {
+                log.error("Failed to update datasource[{}]", jobParameter.getName(), e);
+            } finally {
+                lockService.releaseLock(lock);
+            }
+        };
+    }
+
+    protected void updateDatasource(final ScheduledJobParameter jobParameter, final Runnable renewLock) throws IOException {
+        Datasource datasource = datasourceDao.getDatasource(jobParameter.getName());
+        /*
+         * If a delete request arrives while this update task is still queued behind other update tasks,
+         * the delete is processed first because this task has not acquired a lock yet. By the time it is
+         * this datasource's turn to run, the datasource may already be deleted, so the update stops when
+         * the datasource does not exist.
+         */
+        if (datasource == null) {
+            log.info("Datasource[{}] does not exist", jobParameter.getName());
+            return;
+        }
+
+        if (DatasourceState.AVAILABLE.equals(datasource.getState()) == false) {
+            log.error("Invalid datasource state. 
Expecting {} but received {}", DatasourceState.AVAILABLE, datasource.getState()); + datasource.disable(); + datasource.getUpdateStats().setLastFailedAt(Instant.now()); + datasourceDao.updateDatasource(datasource); + return; + } + try { + datasourceUpdateService.deleteUnusedIndices(datasource); + if (DatasourceTask.DELETE_UNUSED_INDICES.equals(datasource.getTask()) == false) { + datasourceUpdateService.updateOrCreateThreatIntelFeedData(datasource, renewLock); + } + datasourceUpdateService.deleteUnusedIndices(datasource); + } catch (Exception e) { + log.error("Failed to update datasource for {}", datasource.getName(), e); + datasource.getUpdateStats().setLastFailedAt(Instant.now()); + datasourceDao.updateDatasource(datasource); + } finally { //post processing + datasourceUpdateService.updateDatasource(datasource, datasource.getSchedule(), DatasourceTask.ALL); + } + } + +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/threatintel/jobscheduler/DatasourceTask.java b/src/main/java/org/opensearch/securityanalytics/threatintel/jobscheduler/DatasourceTask.java new file mode 100644 index 000000000..b0e9ac184 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatintel/jobscheduler/DatasourceTask.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.jobscheduler; + +/** + * Task that {@link DatasourceRunner} will run + */ +public enum DatasourceTask { + /** + * Do everything + */ + ALL, + + /** + * Only delete unused indices + */ + DELETE_UNUSED_INDICES +} diff --git a/src/main/java/org/opensearch/securityanalytics/threatintel/jobscheduler/DatasourceUpdateService.java b/src/main/java/org/opensearch/securityanalytics/threatintel/jobscheduler/DatasourceUpdateService.java new file mode 100644 index 000000000..3babb21d3 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/threatintel/jobscheduler/DatasourceUpdateService.java @@ -0,0 +1,296 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.jobscheduler; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.io.IOException; +import java.net.URL; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; + +import org.apache.commons.csv.CSVParser; +import org.apache.commons.csv.CSVRecord; +import org.opensearch.OpenSearchException; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; + +import org.opensearch.core.rest.RestStatus; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.securityanalytics.model.DetectorTrigger; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelFeedParser; +import org.opensearch.securityanalytics.threatIntel.common.DatasourceManifest; +import org.opensearch.securityanalytics.threatIntel.dao.DatasourceDao; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelFeedDataService; +import org.opensearch.securityanalytics.threatIntel.common.DatasourceState; +import org.opensearch.securityanalytics.util.SecurityAnalyticsException; + +public class DatasourceUpdateService { + private static final Logger log = LogManager.getLogger(DetectorTrigger.class); + + private 
static final int SLEEP_TIME_IN_MILLIS = 5000; // 5 seconds
+    private static final int MAX_WAIT_TIME_FOR_REPLICATION_TO_COMPLETE_IN_MILLIS = 10 * 60 * 60 * 1000; // 10 hours
+    private final ClusterService clusterService;
+    private final ClusterSettings clusterSettings;
+    private final DatasourceDao datasourceDao;
+    private final ThreatIntelFeedDataService threatIntelFeedDataService;
+
+    public DatasourceUpdateService(
+        final ClusterService clusterService,
+        final DatasourceDao datasourceDao,
+        final ThreatIntelFeedDataService threatIntelFeedDataService
+    ) {
+        this.clusterService = clusterService;
+        this.clusterSettings = clusterService.getClusterSettings();
+        this.datasourceDao = datasourceDao;
+        this.threatIntelFeedDataService = threatIntelFeedDataService;
+    }
+
+    // functions used in Datasource Runner
+    /**
+     * Delete all indices except the one currently in use
+     *
+     * @param datasource the datasource
+     */
+    public void deleteUnusedIndices(final Datasource datasource) {
+        try {
+            List<String> indicesToDelete = datasource.getIndices()
+                .stream()
+                .filter(index -> index.equals(datasource.currentIndexName()) == false)
+                .collect(Collectors.toList());
+
+            List<String> deletedIndices = deleteIndices(indicesToDelete);
+
+            if (deletedIndices.isEmpty() == false) {
+                datasource.getIndices().removeAll(deletedIndices);
+                datasourceDao.updateDatasource(datasource);
+            }
+        } catch (Exception e) {
+            log.error("Failed to delete old indices for {}", datasource.getName(), e);
+        }
+    }
+
+    /**
+     * Update datasource with given systemSchedule and task
+     *
+     * @param datasource datasource to update
+     * @param systemSchedule new system schedule value
+     * @param task new task value
+     */
+    public void updateDatasource(final Datasource datasource, final IntervalSchedule systemSchedule, final DatasourceTask task) {
+        boolean updated = false;
+        if (datasource.getSchedule().equals(systemSchedule) == false) {
+            datasource.setSchedule(systemSchedule);
+            updated = true;
+        }
+        if (datasource.getTask().equals(task) == false) {
+            datasource.setTask(task);
+            updated = true;
+        }
+        if (updated) {
+            datasourceDao.updateDatasource(datasource);
+        }
+    }
+
+    private List<String> deleteIndices(final List<String> indicesToDelete) {
+        List<String> deletedIndices = new ArrayList<>(indicesToDelete.size());
+        for (String index : indicesToDelete) {
+            if (clusterService.state().metadata().hasIndex(index) == false) {
+                deletedIndices.add(index);
+                continue;
+            }
+            try {
+                threatIntelFeedDataService.deleteThreatIntelDataIndex(index);
+                deletedIndices.add(index);
+            } catch (Exception e) {
+                log.error("Failed to delete an index [{}]", index, e);
+            }
+        }
+        return deletedIndices;
+    }
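+
+    /*
+     * Wiring sketch (assumed, mirroring the constructor above): the plugin would
+     * build this service once at startup, e.g.
+     *
+     *   DatasourceUpdateService updateService =
+     *       new DatasourceUpdateService(clusterService, datasourceDao, threatIntelFeedDataService);
+     *
+     * and hand it to DatasourceRunner#initialize together with the executor and lock service.
+     */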
+ * + * @param datasource the datasource + * @param renewLock runnable to renew lock + * + * @throws IOException + */ + public void updateOrCreateThreatIntelFeedData(final Datasource datasource, final Runnable renewLock) throws IOException { + URL url = new URL(datasource.getDatabase().getEndpoint()); + DatasourceManifest manifest = DatasourceManifest.Builder.build(url); + + Instant startTime = Instant.now(); + String indexName = setupIndex(datasource); + String[] header; + List<String> fieldsToStore; + boolean succeeded; + + // switch on the feed type declared in the manifest + switch (manifest.getFeedType()) { + case "csv": + try (CSVParser reader = ThreatIntelFeedParser.getThreatIntelFeedReaderCSV(manifest)) { + // iterate until we find the first line that is neither a '#' comment nor blank + CSVRecord findHeader = reader.iterator().next(); + while (findHeader.get(0).charAt(0) == '#' || findHeader.get(0).charAt(0) == ' ') { + findHeader = reader.iterator().next(); + } + CSVRecord headerLine = findHeader; + header = ThreatIntelFeedParser.validateHeader(headerLine).values(); + fieldsToStore = Arrays.asList(header).subList(1, header.length); + if (datasource.isCompatible(fieldsToStore) == false) { + log.error("Exception: new fields do not contain all old fields"); + throw new OpenSearchException( + "new fields [{}] do not contain all old fields [{}]", + fieldsToStore.toString(), + datasource.getDatabase().getFields().toString() + ); + } + threatIntelFeedDataService.saveThreatIntelFeedDataCSV(indexName, header, reader.iterator(), renewLock, manifest); + } + succeeded = true; + break; + default: + // the feed type doesn't match any of the supported feed types; flag the update as failed + succeeded = false; + fieldsToStore = null; + } + + if (!succeeded) { + log.error("Unsupported feed type: {}", manifest.getFeedType()); + throw new OpenSearchException("unsupported feed type [{}]", manifest.getFeedType()); + } + + waitUntilAllShardsStarted(indexName, MAX_WAIT_TIME_FOR_REPLICATION_TO_COMPLETE_IN_MILLIS); + Instant endTime = Instant.now(); + updateDatasourceAsSucceeded(indexName, datasource, manifest, fieldsToStore, startTime, endTime); + } + + // helper functions + /** + * Update datasource as succeeded + * + * @param newIndexName the new index name + * @param datasource the datasource + * @param manifest the manifest + * @param fields the stored field names + * @param startTime start time of the update + * @param endTime end time of the update + */ + private void updateDatasourceAsSucceeded( + final String newIndexName, + final Datasource datasource, + final DatasourceManifest manifest, + final List<String> fields, + final Instant startTime, + final Instant endTime + ) { + datasource.setCurrentIndex(newIndexName); + datasource.setDatabase(manifest, fields); + datasource.getUpdateStats().setLastSucceededAt(endTime); + datasource.getUpdateStats().setLastProcessingTimeInMillis(endTime.toEpochMilli() - startTime.toEpochMilli()); + datasource.enable(); + datasource.setState(DatasourceState.AVAILABLE); + datasourceDao.updateDatasource(datasource); + log.info( + "threat intel feed database creation succeeded for {} and took {}", + datasource.getName(), + Duration.between(startTime, endTime) + ); + } + + /** + * Setup index to add a new threat intel feed data + * + * @param datasource the datasource + * @return new index name + */ + private String setupIndex(final Datasource datasource) { + String indexName = datasource.newIndexName(UUID.randomUUID().toString()); + datasource.getIndices().add(indexName); + datasourceDao.updateDatasource(datasource); + threatIntelFeedDataService.createIndexIfNotExists(indexName); + return indexName; + } + + /** + * We wait until all shards are ready to serve search requests before updating datasource metadata to + * point 
to a new index so that there won't be latency degradation during threat intel feed data update + * + * @param indexName the index name + * @param timeout the timeout in millis + */ + protected void waitUntilAllShardsStarted(final String indexName, final int timeout) { + Instant start = Instant.now(); + try { + while (Instant.now().toEpochMilli() - start.toEpochMilli() < timeout) { + if (clusterService.state().routingTable().allShards(indexName).stream().allMatch(shard -> shard.started())) { + return; + } + Thread.sleep(SLEEP_TIME_IN_MILLIS); + } + throw new OpenSearchException( + "index[{}] replication did not complete after {} millis", + indexName, + timeout + ); + } catch (InterruptedException e) { + log.error("Interrupted while waiting for shards of [{}] to start", indexName, e); + throw new SecurityAnalyticsException("Runtime exception", RestStatus.INTERNAL_SERVER_ERROR, e); //TODO + } + } + + /** + * Determine if update is needed or not + * + * Update is needed when all following conditions are met + * 1. updatedAt value in datasource is equal to or before the updatedAt value in manifest + * 2. SHA256 hash value in datasource is different from the SHA256 hash value in manifest + * + * @param datasource the datasource + * @param manifest the manifest + * @return true if an update is needed + */ + private boolean shouldUpdate(final Datasource datasource, final DatasourceManifest manifest) { +// if (datasource.getDatabase().getUpdatedAt() != null +// && datasource.getDatabase().getUpdatedAt().toEpochMilli() > manifest.getUpdatedAt()) { +// return false; +// } + +// if (manifest.getSha256Hash().equals(datasource.getDatabase().getSha256Hash())) { +// return false; +// } + return true; + } + + /** + * Return header fields of threat intel feed data with given url of a manifest file + * + * The first column is the ip range field regardless of its header name. + * Therefore, we don't store the first column's header name. 
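+ * For example, a feed whose header line is "ip,source,last_seen" yields the header fields [source, last_seen] (the column names here are illustrative).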
+ * + * @param manifestUrl the url of a manifest file + * @return header fields of threat intel feed + */ + public List<String> getHeaderFields(String manifestUrl) throws IOException { + URL url = new URL(manifestUrl); + DatasourceManifest manifest = DatasourceManifest.Builder.build(url); + + try (CSVParser reader = ThreatIntelFeedParser.getThreatIntelFeedReaderCSV(manifest)) { + String[] fields = reader.iterator().next().values(); + return Arrays.asList(fields).subList(1, fields.length); + } + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java index ae2afc1f3..81c548114 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java @@ -96,6 +96,7 @@ import org.opensearch.securityanalytics.rules.backend.QueryBackend; import org.opensearch.securityanalytics.rules.exceptions.SigmaError; import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; +import org.opensearch.securityanalytics.threatIntel.DetectorThreatIntelService; import org.opensearch.securityanalytics.util.DetectorIndices; import org.opensearch.securityanalytics.util.DetectorUtils; import org.opensearch.securityanalytics.util.IndexUtils; @@ -155,6 +156,7 @@ public class TransportIndexDetectorAction extends HandledTransportAction DocLevelQuery docLevelQuery = new DocLevelQuery(id, name, Collections.emptyList(), actualQuery, tags); docLevelQueries.add(docLevelQuery); } + try { + if (detector.getThreatIntelEnabled()) { + DocLevelQuery docLevelQueryFromThreatIntel = detectorThreatIntelService.createDocLevelQueryFromThreatIntel(detector); + docLevelQueries.add(docLevelQueryFromThreatIntel); + } + } catch (Exception e) { + // not failing detector creation if any fatal exception occurs during doc level query creation from threat intel feed data + log.error("Failed to convert threat intel feed to doc level query. 
Proceeding with detector creation", e); + } DocLevelMonitorInput docLevelMonitorInput = new DocLevelMonitorInput(detector.getName(), detector.getInputs().get(0).getIndices(), docLevelQueries); docLevelMonitorInputs.add(docLevelMonitorInput); diff --git a/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension b/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension new file mode 100644 index 000000000..0ffeb24aa --- /dev/null +++ b/src/main/resources/META-INF/services/org.opensearch.jobscheduler.spi.JobSchedulerExtension @@ -0,0 +1 @@ +org.opensearch.securityanalytics.SecurityAnalyticsPlugin \ No newline at end of file diff --git a/src/main/resources/mappings/detectors.json b/src/main/resources/mappings/detectors.json index e1e160d5f..c4a42d53a 100644 --- a/src/main/resources/mappings/detectors.json +++ b/src/main/resources/mappings/detectors.json @@ -62,6 +62,9 @@ "enabled": { "type": "boolean" }, + "threat_intel_enabled": { + "type": "boolean" + }, "enabled_time": { "type": "date", "format": "strict_date_time||epoch_millis" diff --git a/src/main/resources/mappings/threat_intel_datasource_mapping.json b/src/main/resources/mappings/threat_intel_datasource_mapping.json new file mode 100644 index 000000000..5e039928d --- /dev/null +++ b/src/main/resources/mappings/threat_intel_datasource_mapping.json @@ -0,0 +1,118 @@ +{ + "properties": { + "database": { + "properties": { + "feed_id": { + "type": "text" + }, + "feed_name": { + "type": "text" + }, + "feed_format": { + "type": "text" + }, + "endpoint": { + "type": "text" + }, + "description": { + "type": "text" + }, + "organization": { + "type": "text" + }, + "contained_iocs_field": { + "type": "text" + }, + "ioc_col": { + "type": "text" + }, + "fields": { + "type": "text" + } + } + }, + "enabled_time": { + "type": "long" + }, + "indices": { + "type": "text" + }, + "last_update_time": { + "type": "long" + }, + "name": { + "type": "text" + }, + "schedule": { + "properties": { + "interval": { + "properties": { + "period": { + "type": "long" + }, + "start_time": { + "type": "long" + }, + "unit": { + "type": "text" + } + } + } + } + }, + "state": { + "type": "text" + }, + "task": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "update_enabled": { + "type": "boolean" + }, + "update_stats": { + "properties": { + "last_failed_at_in_epoch_millis": { + "type": "long" + }, + "last_processing_time_in_millis": { + "type": "long" + }, + "last_skipped_at_in_epoch_millis": { + "type": "long" + }, + "last_succeeded_at_in_epoch_millis": { + "type": "long" + } + } + }, + "user_schedule": { + "properties": { + "interval": { + "properties": { + "period": { + "type": "long" + }, + "start_time": { + "type": "long" + }, + "unit": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + } + } + } + } +} \ No newline at end of file diff --git a/src/main/resources/mappings/threat_intel_feed_mapping.json b/src/main/resources/mappings/threat_intel_feed_mapping.json new file mode 100644 index 000000000..2e775cf8e --- /dev/null +++ b/src/main/resources/mappings/threat_intel_feed_mapping.json @@ -0,0 +1,27 @@ +{ + "dynamic": "strict", + "_meta" : { + "schema_version": 1 + }, + "properties": { + "schema_version": { + "type": "integer" + }, + "ioc_type": { + "type": "keyword" + }, + "ioc_value": { + "type": "keyword" + }, + "feed_id": { + "type": "keyword" + }, + "timestamp": { + "type": "date", 
+ "format": "strict_date_time||epoch_millis" + }, + "type": { + "type": "keyword" + } + } +} diff --git a/src/main/resources/threatIntelFeedInfo/feodo.yml b/src/main/resources/threatIntelFeedInfo/feodo.yml new file mode 100644 index 000000000..4acbf40e4 --- /dev/null +++ b/src/main/resources/threatIntelFeedInfo/feodo.yml @@ -0,0 +1,6 @@ +url: "https://feodotracker.abuse.ch/downloads/ipblocklist_aggressive.csv" +name: "ipblocklist_aggressive.csv" +feedFormat: "csv" +org: "Feodo" +iocTypes: ["ip"] +description: "" \ No newline at end of file diff --git a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java index dde7efbb5..a3e73e96f 100644 --- a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java +++ b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java @@ -28,6 +28,7 @@ import org.opensearch.securityanalytics.model.DetectorInput; import org.opensearch.securityanalytics.model.DetectorRule; import org.opensearch.securityanalytics.model.DetectorTrigger; +import org.opensearch.securityanalytics.model.ThreatIntelFeedData; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.rest.OpenSearchRestTestCase; @@ -149,7 +150,7 @@ public static Detector randomDetector(String name, DetectorTrigger trigger = new DetectorTrigger(null, "windows-trigger", "1", List.of(randomDetectorType()), List.of("QuarksPwDump Clearing Access History"), List.of("high"), List.of("T0008"), List.of()); triggers.add(trigger); } - return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, detectorType, user, inputs, triggers, Collections.singletonList(""), "", "", "", "", "", "", Collections.emptyMap(), Collections.emptyList()); + return new Detector(null, null, name, enabled, schedule, lastUpdateTime, enabledTime, detectorType, user, inputs, triggers, Collections.singletonList(""), "", "", "", "", "", "", Collections.emptyMap(), Collections.emptyList(), false); } public static CustomLogType randomCustomLogType(String name, String description, String category, String source) { @@ -168,6 +169,15 @@ public static CustomLogType randomCustomLogType(String name, String description, return new CustomLogType(null, null, name, description, category, source, null); } + public static ThreatIntelFeedData randomThreatIntelFeedData() { + return new ThreatIntelFeedData( + "IP_ADDRESS", + "ip", + "alientVault", + Instant.now() + ); + } + public static Detector randomDetectorWithNoUser() { String name = OpenSearchRestTestCase.randomAlphaOfLength(10); String detectorType = randomDetectorType(); @@ -197,7 +207,8 @@ public static Detector randomDetectorWithNoUser() { "", "", Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); } @@ -429,6 +440,12 @@ public static String toJsonStringWithUser(Detector detector) throws IOException return BytesReference.bytes(builder).utf8ToString(); } + public static String toJsonString(ThreatIntelFeedData threatIntelFeedData) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder = threatIntelFeedData.toXContent(builder, ToXContent.EMPTY_PARAMS); + return BytesReference.bytes(builder).utf8ToString(); + } + public static User randomUser() { return new User( OpenSearchRestTestCase.randomAlphaOfLength(10), diff --git a/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java b/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java index 
db366056b..ca98a1144 100644 --- a/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java +++ b/src/test/java/org/opensearch/securityanalytics/action/IndexDetectorResponseTests.java @@ -50,7 +50,8 @@ public void testIndexDetectorPostResponse() throws IOException { null, DetectorMonitorConfig.getFindingsIndex("others_application"), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); IndexDetectorResponse response = new IndexDetectorResponse("1234", 1L, RestStatus.OK, detector); Assert.assertNotNull(response); @@ -69,5 +70,6 @@ public void testIndexDetectorPostResponse() throws IOException { Assert.assertTrue(newResponse.getDetector().getMonitorIds().contains("1")); Assert.assertTrue(newResponse.getDetector().getMonitorIds().contains("2")); Assert.assertTrue(newResponse.getDetector().getMonitorIds().contains("3")); + Assert.assertFalse(newResponse.getDetector().getThreatIntelEnabled()); } } \ No newline at end of file diff --git a/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java b/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java index 78dacd6e1..d250d2eef 100644 --- a/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java +++ b/src/test/java/org/opensearch/securityanalytics/alerts/AlertingServiceTests.java @@ -65,7 +65,8 @@ public void testGetAlerts_success() { null, DetectorMonitorConfig.getFindingsIndex("others_application"), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); @@ -242,7 +243,8 @@ public void testGetFindings_getFindingsByMonitorIdFailures() { null, DetectorMonitorConfig.getFindingsIndex("others_application"), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); diff --git a/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java b/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java index 5c28ba65b..6551f579c 100644 --- a/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java +++ b/src/test/java/org/opensearch/securityanalytics/findings/FindingServiceTests.java @@ -5,6 +5,12 @@ package org.opensearch.securityanalytics.findings; +import java.io.BufferedReader; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStreamReader; +import java.net.URL; +import java.net.URLConnection; import java.time.Instant; import java.time.ZoneId; import java.util.ArrayDeque; @@ -65,7 +71,8 @@ public void testGetFindings_success() { null, DetectorMonitorConfig.getFindingsIndex("others_application"), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); @@ -186,7 +193,8 @@ public void testGetFindings_getFindingsByMonitorIdFailure() { null, DetectorMonitorConfig.getFindingsIndex("others_application"), Collections.emptyMap(), - Collections.emptyList() + Collections.emptyList(), + false ); GetDetectorResponse getDetectorResponse = new GetDetectorResponse("detector_id123", 1L, RestStatus.OK, detector); diff --git a/src/test/java/org/opensearch/securityanalytics/model/XContentTests.java 
b/src/test/java/org/opensearch/securityanalytics/model/XContentTests.java index f2ec8c5cc..89f447440 100644 --- a/src/test/java/org/opensearch/securityanalytics/model/XContentTests.java +++ b/src/test/java/org/opensearch/securityanalytics/model/XContentTests.java @@ -17,8 +17,10 @@ import static org.opensearch.securityanalytics.TestHelpers.parser; import static org.opensearch.securityanalytics.TestHelpers.randomDetector; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithNoUser; +import static org.opensearch.securityanalytics.TestHelpers.randomThreatIntelFeedData; import static org.opensearch.securityanalytics.TestHelpers.randomUser; import static org.opensearch.securityanalytics.TestHelpers.randomUserEmpty; +import static org.opensearch.securityanalytics.TestHelpers.toJsonString; import static org.opensearch.securityanalytics.TestHelpers.toJsonStringWithUser; public class XContentTests extends OpenSearchTestCase { @@ -193,4 +195,12 @@ public void testDetectorParsingWithNoUser() throws IOException { Detector parsedDetector = Detector.parse(parser(detectorString), null, null); Assert.assertEquals("Round tripping Detector doesn't work", detector, parsedDetector); } + + public void testThreatIntelFeedParsing() throws IOException { + ThreatIntelFeedData tifd = randomThreatIntelFeedData(); + + String tifdString = toJsonString(tifd); + ThreatIntelFeedData parsedTifd = ThreatIntelFeedData.parse(parser(tifdString), null, null); + Assert.assertEquals("Round tripping Threat intel feed data model doesn't work", tifd, parsedTifd); + } } \ No newline at end of file diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/IT/JobRunnerIT.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/IT/JobRunnerIT.java new file mode 100644 index 000000000..253f6e59a --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/IT/JobRunnerIT.java @@ -0,0 +1,123 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.securityanalytics.threatIntel.IT; + +import org.junit.Assert; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.Datasource; +import org.opensearch.test.rest.OpenSearchRestTestCase; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; + +public class JobRunnerIT extends ThreatIntelExtensionIntegTestCase { + + public void testJobCreateWithCorrectParams() throws IOException { + Datasource jobParameter = new Datasource(); + jobParameter.setName("sample-job-it"); +// jobParameter.setIndexToWatch("http-logs"); + jobParameter.setSchedule(new IntervalSchedule(Instant.now(), 1, ChronoUnit.MINUTES)); +// jobParameter.setLockDurationSeconds(120L); + + // Creates a new watcher job. + String jobId = OpenSearchRestTestCase.randomAlphaOfLength(10); + Datasource schedJobParameter = createWatcherJob(jobId, jobParameter); + + // Asserts that job is created with correct parameters. 
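+ // lock_duration_seconds is not set explicitly above, so the equality check below also covers the default value surviving the round trip.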
+ Assert.assertEquals(jobParameter.getName(), schedJobParameter.getName()); +// Assert.assertEquals(jobParameter.getIndexToWatch(), schedJobParameter.getIndexToWatch()); + Assert.assertEquals(jobParameter.getLockDurationSeconds(), schedJobParameter.getLockDurationSeconds()); + } + + public void testJobDeleteWithDescheduleJob() throws Exception { + String index = createTestIndex(); + Datasource jobParameter = new Datasource(); + jobParameter.setName("sample-job-it"); +// jobParameter.setIndexToWatch(index); + jobParameter.setSchedule(new IntervalSchedule(Instant.now(), 1, ChronoUnit.MINUTES)); +// jobParameter.setLockDurationSeconds(120L); + + // Creates a new watcher job. + String jobId = OpenSearchRestTestCase.randomAlphaOfLength(10); + Datasource schedJobParameter = createWatcherJob(jobId, jobParameter); + + // wait till the job runner runs for the first time after 1 min & inserts a record into the watched index & then delete the job. + waitAndDeleteWatcherJob(schedJobParameter.getIndexToWatch(), jobId); + long actualCount = waitAndCountRecords(index, 130000); + + // Asserts that in the last 3 mins, no new job ran to insert a record into the watched index & all locks are deleted for the job. + Assert.assertEquals(1, actualCount); + Assert.assertEquals(0L, getLockTimeByJobId(jobId)); + } + + public void testJobUpdateWithRescheduleJob() throws Exception { + String index = createTestIndex(); + Datasource jobParameter = new Datasource(); + jobParameter.setName("sample-job-it"); +// jobParameter.setIndexToWatch(index); + jobParameter.setSchedule(new IntervalSchedule(Instant.now(), 1, ChronoUnit.MINUTES)); +// jobParameter.setLockDurationSeconds(120L); + + // Creates a new watcher job. + String jobId = OpenSearchRestTestCase.randomAlphaOfLength(10); + Datasource schedJobParameter = createWatcherJob(jobId, jobParameter); + + // update the job params to now watch a new index. + String newIndex = createTestIndex(); +// jobParameter.setIndexToWatch(newIndex); + + // wait till the job runner runs for the first time after 1 min & inserts a record into the watched index & then update the job with + // new params. + waitAndCreateWatcherJob(schedJobParameter.getIndexToWatch(), jobId, jobParameter); + long actualCount = waitAndCountRecords(newIndex, 130000); + + // Asserts that the job runner has the updated params & it inserted the record in the new watched index. + Assert.assertEquals(1, actualCount); + long prevIndexActualCount = waitAndCountRecords(index, 0); + + // Asserts that the job runner no longer updates the old index as the job params have been updated. + Assert.assertEquals(1, prevIndexActualCount); + } + + public void testAcquiredLockPreventExecOfTasks() throws Exception { + String index = createTestIndex(); + Datasource jobParameter = new Datasource(); + jobParameter.setName("sample-job-lock-test-it"); +// jobParameter.setIndexToWatch(index); + // ensures that the next job tries to run even before the previous job finished & released its lock. Also look at + // SampleJobRunner.runTaskForLockIntegrationTests + jobParameter.setSchedule(new IntervalSchedule(Instant.now(), 1, ChronoUnit.MINUTES)); +// jobParameter.setLockDurationSeconds(120L); + + // Creates a new watcher job. + String jobId = OpenSearchRestTestCase.randomAlphaOfLength(10); + createWatcherJob(jobId, jobParameter); + + // Asserts that the job runner is running for the first time & it has inserted a new record into the watched index. 
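+ // waitAndCountRecords simply sleeps for the given number of millis before counting, so 80s is enough for one tick of the 1-minute schedule.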
+ long actualCount = waitAndCountRecords(index, 80000); + Assert.assertEquals(1, actualCount); + + // gets the lock time for the lock acquired for running first job. + long lockTime = getLockTimeByJobId(jobId); + + // Asserts that the second job could not run & hence no new record is inserted into the watched index. + // Also asserts that the old lock acquired for running first job is still not released. + actualCount = waitAndCountRecords(index, 80000); + Assert.assertEquals(1, actualCount); + Assert.assertTrue(doesLockExistByLockTime(lockTime)); + + // Asserts that the new job ran 2 mins after the first job lock was released, hence a new record is inserted into the watched + // index. + // Also asserts that the old lock is released. + actualCount = waitAndCountRecords(index, 130000); + Assert.assertEquals(2, actualCount); + Assert.assertFalse(doesLockExistByLockTime(lockTime)); + } +} diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/IT/JobSchedulerExtensionPluginIT.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/IT/JobSchedulerExtensionPluginIT.java new file mode 100644 index 000000000..1a5cd9d9c --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/IT/JobSchedulerExtensionPluginIT.java @@ -0,0 +1,40 @@ +package org.opensearch.securityanalytics.threatIntel.IT; + +import org.junit.Assert; +import org.opensearch.action.admin.cluster.health.ClusterHealthRequest; +import org.opensearch.action.admin.cluster.health.ClusterHealthResponse; +import org.opensearch.action.admin.cluster.node.info.NodeInfo; +import org.opensearch.action.admin.cluster.node.info.NodesInfoRequest; +import org.opensearch.action.admin.cluster.node.info.NodesInfoResponse; +import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; +import org.opensearch.cluster.health.ClusterHealthStatus; +import org.opensearch.plugins.PluginInfo; +import org.opensearch.test.OpenSearchIntegTestCase; + +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class JobSchedulerExtensionPluginIT extends OpenSearchIntegTestCase { + + public void testPluginsAreInstalled() { + ClusterHealthRequest request = new ClusterHealthRequest(); + ClusterHealthResponse response = OpenSearchIntegTestCase.client().admin().cluster().health(request).actionGet(); + Assert.assertEquals(ClusterHealthStatus.GREEN, response.getStatus()); + + NodesInfoRequest nodesInfoRequest = new NodesInfoRequest(); + nodesInfoRequest.addMetric(NodesInfoRequest.Metric.PLUGINS.metricName()); + NodesInfoResponse nodesInfoResponse = OpenSearchIntegTestCase.client().admin().cluster().nodesInfo(nodesInfoRequest).actionGet(); + List<PluginInfo> pluginInfos = nodesInfoResponse.getNodes() + .stream() + .flatMap( + (Function<NodeInfo, Stream<PluginInfo>>) nodeInfo -> nodeInfo.getInfo(PluginsAndModules.class).getPluginInfos().stream() + ) + .collect(Collectors.toList()); + Assert.assertTrue(pluginInfos.stream().anyMatch(pluginInfo -> pluginInfo.getName().equals("opensearch-job-scheduler"))); + Assert.assertTrue( + pluginInfos.stream().anyMatch(pluginInfo -> pluginInfo.getName().equals("opensearch-security-analytics")) + ); + } +} diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/IT/ThreatIntelExtensionIntegTestCase.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/IT/ThreatIntelExtensionIntegTestCase.java new file mode 100644 index 000000000..54e8fb12f --- /dev/null +++ 
b/src/test/java/org/opensearch/securityanalytics/threatIntel/IT/ThreatIntelExtensionIntegTestCase.java @@ -0,0 +1,333 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.securityanalytics.threatIntel.IT; + +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.io.entity.StringEntity; +import org.junit.Assert; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.client.WarningsHandler; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.Datasource; +import org.opensearch.test.rest.OpenSearchRestTestCase; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.HashMap; +import java.util.Locale; +import java.util.List; +import java.util.Map; +import java.util.Timer; +import java.util.TimerTask; + +public class ThreatIntelExtensionIntegTestCase extends OpenSearchRestTestCase { + +// protected Datasource createWatcherJob(String jobId, Datasource jobParameter) throws IOException { +// return createWatcherJobWithClient(client(), jobId, jobParameter); +// } + +// protected String createWatcherJobJson(String jobId, String jobParameter) throws IOException { +// return createWatcherJobJsonWithClient(client(), jobId, jobParameter); +// } + +// protected Datasource createWatcherJobWithClient(RestClient client, String jobId, Datasource jobParameter) +// throws IOException { +// Map params = getJobParameterAsMap(jobId, jobParameter); +// Response response = makeRequest(client, "POST", SampleExtensionRestHandler.WATCH_INDEX_URI, params, null); +// Assert.assertEquals("Unable to create a watcher job", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); +// +// Map responseJson = JsonXContent.jsonXContent.createParser( +// NamedXContentRegistry.EMPTY, +// LoggingDeprecationHandler.INSTANCE, +// response.getEntity().getContent() +// ).map(); +// return getJobParameter(client, responseJson.get("_id").toString()); +// } + +// protected String createWatcherJobJsonWithClient(RestClient client, String jobId, String jobParameter) throws IOException { +// Response response = makeRequest( +// client, +// "PUT", +// "/" + SampleExtensionPlugin.JOB_INDEX_NAME + "/_doc/" + jobId + "?refresh", +// Collections.emptyMap(), +// new StringEntity(jobParameter, ContentType.APPLICATION_JSON) +// ); +// Assert.assertEquals( +// "Unable to create a watcher job", +// RestStatus.CREATED, +// RestStatus.fromCode(response.getStatusLine().getStatusCode()) +// ); +// +// Map responseJson = JsonXContent.jsonXContent.createParser( +// NamedXContentRegistry.EMPTY, +// LoggingDeprecationHandler.INSTANCE, +// response.getEntity().getContent() +// 
).map(); +// return responseJson.get("_id").toString(); +// } + +// protected void deleteWatcherJob(String jobId) throws IOException { +// deleteWatcherJobWithClient(client(), jobId); +// } + +// protected void deleteWatcherJobWithClient(RestClient client, String jobId) throws IOException { +// Response response = makeRequest( +// client, +// "DELETE", +// SampleExtensionRestHandler.WATCH_INDEX_URI, +// Collections.singletonMap("id", jobId), +// null +// ); +// +// Assert.assertEquals("Unable to delete a watcher job", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); +// } + + protected Response makeRequest( + RestClient client, + String method, + String endpoint, + Map<String, String> params, + HttpEntity entity, + Header... headers + ) throws IOException { + Request request = new Request(method, endpoint); + RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder(); + options.setWarningsHandler(WarningsHandler.PERMISSIVE); + + for (Header header : headers) { + options.addHeader(header.getName(), header.getValue()); + } + request.setOptions(options.build()); + request.addParameters(params); + if (entity != null) { + request.setEntity(entity); + } + return client.performRequest(request); + } + + protected Map<String, String> getJobParameterAsMap(String jobId, Datasource jobParameter) throws IOException { + Map<String, String> params = new HashMap<>(); + params.put("id", jobId); + params.put("job_name", jobParameter.getName()); +// params.put("index", jobParameter.getIndexToWatch()); + params.put("interval", String.valueOf(((IntervalSchedule) jobParameter.getSchedule()).getInterval())); + params.put("lock_duration_seconds", String.valueOf(jobParameter.getLockDurationSeconds())); + return params; + } + +// @SuppressWarnings("unchecked") +// protected Datasource getJobParameter(RestClient client, String jobId) throws IOException { +// Request request = new Request("POST", "/" + SecurityAnalyticsPlugin.JOB_INDEX_NAME + "/_search"); +// String entity = "{\n" +// + " \"query\": {\n" +// + " \"match\": {\n" +// + " \"_id\": {\n" +// + " \"query\": \"" +// + jobId +// + "\"\n" +// + " }\n" +// + " }\n" +// + " }\n" +// + "}"; +// request.setJsonEntity(entity); +// Response response = client.performRequest(request); +// Map responseJson = JsonXContent.jsonXContent.createParser( +// NamedXContentRegistry.EMPTY, +// LoggingDeprecationHandler.INSTANCE, +// response.getEntity().getContent() +// ).map(); +// Map hit = (Map) ((List) ((Map) responseJson.get("hits")).get("hits")).get( +// 0 +// ); +// Map jobSource = (Map) hit.get("_source"); +// +// Datasource jobParameter = new Datasource(); +// jobParameter.setName(jobSource.get("name").toString()); +//// jobParameter.setIndexToWatch(jobSource.get("index_name_to_watch").toString()); +// +// Map jobSchedule = (Map) jobSource.get("schedule"); +// jobParameter.setSchedule( +// new IntervalSchedule( +// Instant.ofEpochMilli(Long.parseLong(((Map) jobSchedule.get("interval")).get("start_time").toString())), +// Integer.parseInt(((Map) jobSchedule.get("interval")).get("period").toString()), +// ChronoUnit.MINUTES +// ) +// ); +//// jobParameter.setLockDurationSeconds(Long.parseLong(jobSource.get("lock_duration_seconds").toString())); +// return jobParameter; +// } + + protected String createTestIndex() throws IOException { + String index = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); + createTestIndex(index); + return index; + } + + protected void createTestIndex(String index) throws IOException { + createIndex(index, 
Settings.builder().put("index.number_of_shards", 2).put("index.number_of_replicas", 0).build()); + } + + protected void deleteTestIndex(String index) throws IOException { + deleteIndex(index); + } + + protected long countRecordsInTestIndex(String index) throws IOException { + String entity = "{\n" + " \"query\": {\n" + " \"match_all\": {\n" + " }\n" + " }\n" + "}"; + Response response = makeRequest( + client(), + "POST", + "/" + index + "/_count", + Collections.emptyMap(), + new StringEntity(entity, ContentType.APPLICATION_JSON) + ); + Map<String, Object> responseJson = JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + response.getEntity().getContent() + ).map(); + return Integer.parseInt(responseJson.get("count").toString()); + } + + protected void waitAndCreateWatcherJob(String prevIndex, String jobId, Datasource jobParameter) { + Timer timer = new Timer(); + TimerTask timerTask = new TimerTask() { + private int timeoutCounter = 0; + + @Override + public void run() { + try { + long count = countRecordsInTestIndex(prevIndex); + ++timeoutCounter; + if (count == 1) { + createWatcherJob(jobId, jobParameter); + timer.cancel(); + timer.purge(); + } + if (timeoutCounter >= 24) { + timer.cancel(); + timer.purge(); + } + } catch (IOException ex) { + // suppress the exception; the next timer tick will retry + } + } + }; + timer.scheduleAtFixedRate(timerTask, 2000, 5000); + } + + protected void waitAndDeleteWatcherJob(String prevIndex, String jobId) { + Timer timer = new Timer(); + TimerTask timerTask = new TimerTask() { + private int timeoutCounter = 0; + + @Override + public void run() { + try { + long count = countRecordsInTestIndex(prevIndex); + ++timeoutCounter; + if (count == 1) { + deleteWatcherJob(jobId); + timer.cancel(); + timer.purge(); + } + if (timeoutCounter >= 24) { + timer.cancel(); + timer.purge(); + } + } catch (IOException ex) { + // suppress the exception; the next timer tick will retry + } + } + }; + timer.scheduleAtFixedRate(timerTask, 2000, 5000); + } + + protected long waitAndCountRecords(String index, long waitForInMs) throws Exception { + Thread.sleep(waitForInMs); + return countRecordsInTestIndex(index); + } + + @SuppressWarnings("unchecked") + protected long getLockTimeByJobId(String jobId) throws IOException { + String entity = "{\n" + + " \"query\": {\n" + + " \"match\": {\n" + + " \"job_id\": {\n" + + " \"query\": \"" + + jobId + + "\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + Response response = makeRequest( + client(), + "POST", + "/" + ".opendistro-job-scheduler-lock" + "/_search", + Collections.emptyMap(), + new StringEntity(entity, ContentType.APPLICATION_JSON) + ); + Map<String, Object> responseJson = JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + response.getEntity().getContent() + ).map(); + List<Map<String, Object>> hits = (List<Map<String, Object>>) ((Map<String, Object>) responseJson.get("hits")).get("hits"); + if (hits.size() == 0) { + return 0L; + } + Map<String, Object> lockSource = (Map<String, Object>) hits.get(0).get("_source"); + return Long.parseLong(lockSource.get("lock_time").toString()); + } + + @SuppressWarnings("unchecked") + protected boolean doesLockExistByLockTime(long lockTime) throws IOException { + String entity = "{\n" + + " \"query\": {\n" + + " \"match\": {\n" + + " \"lock_time\": {\n" + + " \"query\": " + + lockTime + + "\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + Response response = makeRequest( + client(), + "POST", + "/" + ".opendistro-job-scheduler-lock" + "/_search", + Collections.emptyMap(), + new StringEntity(entity, ContentType.APPLICATION_JSON) + ); + Map<String, Object> responseJson 
= JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + response.getEntity().getContent() + ).map(); + List<Map<String, Object>> hits = (List<Map<String, Object>>) ((Map<String, Object>) responseJson.get("hits")).get("hits"); + return hits.size() == 1; + } +} + diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelTestCase.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelTestCase.java new file mode 100644 index 000000000..fb805007f --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelTestCase.java @@ -0,0 +1,303 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel; + +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +import java.io.File; +import java.net.MalformedURLException; +import java.net.URISyntaxException; +import java.nio.file.Paths; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Locale; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import org.junit.After; +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionType; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.routing.RoutingTable; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Randomness; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.ingest.IngestMetadata; +import org.opensearch.ingest.IngestService; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.jobscheduler.spi.utils.LockService; +import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; +import org.opensearch.securityanalytics.threatIntel.common.DatasourceState; +import org.opensearch.securityanalytics.threatIntel.common.ThreatIntelExecutor; +import org.opensearch.securityanalytics.threatIntel.common.ThreatIntelLockService; +import org.opensearch.securityanalytics.threatIntel.dao.DatasourceDao; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.DatasourceTask; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.Datasource; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.DatasourceUpdateService; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskListener; +import org.opensearch.test.client.NoOpNodeClient; +import org.opensearch.test.rest.RestActionTestCase; +import org.opensearch.threadpool.ThreadPool; + +public abstract class ThreatIntelTestCase extends RestActionTestCase { + @Mock + protected ClusterService clusterService; + @Mock + protected DatasourceUpdateService datasourceUpdateService; + @Mock + protected DatasourceDao datasourceDao; + @Mock + protected ThreatIntelExecutor 
threatIntelExecutor; + @Mock + protected ThreatIntelFeedDataService threatIntelFeedDataService; + @Mock + protected ClusterState clusterState; + @Mock + protected Metadata metadata; + @Mock + protected IngestService ingestService; + @Mock + protected ActionFilters actionFilters; + @Mock + protected ThreadPool threadPool; + @Mock + protected ThreatIntelLockService threatIntelLockService; + @Mock + protected RoutingTable routingTable; + protected IngestMetadata ingestMetadata; + protected NoOpNodeClient client; + protected VerifyingClient verifyingClient; + protected LockService lockService; + protected ClusterSettings clusterSettings; + protected Settings settings; + private AutoCloseable openMocks; + + @Before + public void prepareThreatIntelTestCase() { + openMocks = MockitoAnnotations.openMocks(this); + settings = Settings.EMPTY; + client = new NoOpNodeClient(this.getTestName()); + verifyingClient = spy(new VerifyingClient(this.getTestName())); + clusterSettings = new ClusterSettings(settings, new HashSet<>(SecurityAnalyticsSettings.settings())); + lockService = new LockService(client, clusterService); + ingestMetadata = new IngestMetadata(Collections.emptyMap()); + when(metadata.custom(IngestMetadata.TYPE)).thenReturn(ingestMetadata); + when(clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(clusterState.getMetadata()).thenReturn(metadata); + when(clusterState.routingTable()).thenReturn(routingTable); + when(ingestService.getClusterService()).thenReturn(clusterService); + when(threadPool.generic()).thenReturn(OpenSearchExecutors.newDirectExecutorService()); + } + + @After + public void clean() throws Exception { + openMocks.close(); + client.close(); + verifyingClient.close(); + } + + protected DatasourceState randomStateExcept(DatasourceState state) { + assertNotNull(state); + return Arrays.stream(DatasourceState.values()) + .sequential() + .filter(s -> !s.equals(state)) + .collect(Collectors.toList()) + .get(Randomness.createSecure().nextInt(DatasourceState.values().length - 1)); + } + + protected DatasourceState randomState() { + return Arrays.stream(DatasourceState.values()) + .sequential() + .collect(Collectors.toList()) + .get(Randomness.createSecure().nextInt(DatasourceState.values().length)); + } + + protected DatasourceTask randomTask() { + return Arrays.stream(DatasourceTask.values()) + .sequential() + .collect(Collectors.toList()) + .get(Randomness.createSecure().nextInt(DatasourceTask.values().length)); + } + + protected String randomIpAddress() { + return String.format( + Locale.ROOT, + "%d.%d.%d.%d", + Randomness.get().nextInt(255), + Randomness.get().nextInt(255), + Randomness.get().nextInt(255), + Randomness.get().nextInt(255) + ); + } + + protected long randomPositiveLong() { + long value = Randomness.get().nextLong(); + return value < 0 ? -value : value; + } + + /** + * Update interval should be > 0 and < validForInDays. + * For an update test to work, there should be at least one eligible value other than the current update interval. + * Therefore, validForInDays must be at least 3; this helper draws it from 3 to 32. + * Update interval is a random value from 1 to validForInDays - 2. + * The new update value will be validForInDays - 1. 
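+ * For example, with validForInDays = 5 the interval is drawn from 1 to 3, so validForInDays - 1 = 4 is always available as a new, distinct update interval.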
+ */ + protected Datasource randomDatasource(final Instant updateStartTime) { + int validForInDays = 3 + Randomness.get().nextInt(30); + Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS); + Datasource datasource = new Datasource(); + datasource.setName(ThreatIntelTestHelper.randomLowerCaseString()); + datasource.setSchedule( + new IntervalSchedule( + updateStartTime.truncatedTo(ChronoUnit.MILLIS), + 1 + Randomness.get().nextInt(validForInDays - 2), + ChronoUnit.DAYS + ) + ); + datasource.setTask(randomTask()); + datasource.setState(randomState()); + datasource.setCurrentIndex(datasource.newIndexName(UUID.randomUUID().toString())); + datasource.setIndices(Arrays.asList(ThreatIntelTestHelper.randomLowerCaseString(), ThreatIntelTestHelper.randomLowerCaseString())); + datasource.getDatabase() + .setFields(Arrays.asList(ThreatIntelTestHelper.randomLowerCaseString(), ThreatIntelTestHelper.randomLowerCaseString())); + datasource.getDatabase().setFeedId(ThreatIntelTestHelper.randomLowerCaseString()); + datasource.getDatabase().setFeedName(ThreatIntelTestHelper.randomLowerCaseString()); + datasource.getDatabase().setFeedFormat(ThreatIntelTestHelper.randomLowerCaseString()); + datasource.getDatabase().setEndpoint(String.format(Locale.ROOT, "https://%s.com/manifest.json", ThreatIntelTestHelper.randomLowerCaseString())); + datasource.getDatabase().setDescription(ThreatIntelTestHelper.randomLowerCaseString()); + datasource.getDatabase().setOrganization(ThreatIntelTestHelper.randomLowerCaseString()); + datasource.getDatabase().setContained_iocs_field(ThreatIntelTestHelper.randomLowerCaseStringList()); + datasource.getDatabase().setIocCol(ThreatIntelTestHelper.randomLowerCaseString()); + datasource.getUpdateStats().setLastSkippedAt(now); + datasource.getUpdateStats().setLastSucceededAt(now); + datasource.getUpdateStats().setLastFailedAt(now); + datasource.getUpdateStats().setLastProcessingTimeInMillis(randomPositiveLong()); + datasource.setLastUpdateTime(now); + if (Randomness.get().nextInt() % 2 == 0) { + datasource.enable(); + } else { + datasource.disable(); + } + return datasource; + } + + protected Datasource randomDatasource() { + return randomDatasource(Instant.now()); + } + + protected LockModel randomLockModel() { + LockModel lockModel = new LockModel( + ThreatIntelTestHelper.randomLowerCaseString(), + ThreatIntelTestHelper.randomLowerCaseString(), + Instant.now(), + randomPositiveLong(), + false + ); + return lockModel; + } + + /** + * Temporary class of VerifyingClient until this PR(https://github.com/opensearch-project/OpenSearch/pull/7167) + * is merged in OpenSearch core + */ + public static class VerifyingClient extends NoOpNodeClient { + AtomicReference<BiFunction> executeVerifier = new AtomicReference<>(); + AtomicReference<BiFunction> executeLocallyVerifier = new AtomicReference<>(); + + public VerifyingClient(String testName) { + super(testName); + reset(); + } + + /** + * Clears any previously set verifier functions set by {@link #setExecuteVerifier(BiFunction)} and/or + * {@link #setExecuteLocallyVerifier(BiFunction)}. These functions are replaced with functions which will throw an + * {@link AssertionError} if called. + */ + public void reset() { + executeVerifier.set((arg1, arg2) -> { throw new AssertionError(); }); + executeLocallyVerifier.set((arg1, arg2) -> { throw new AssertionError(); }); + } + + /** + * Sets the function that will be called when {@link #doExecute(ActionType, ActionRequest, ActionListener)} is called. 
The given + * function should return either a subclass of {@link ActionResponse} or {@code null}. + * @param verifier A function which is called in place of {@link #doExecute(ActionType, ActionRequest, ActionListener)} + */ + public <Request extends ActionRequest, Response extends ActionResponse> void setExecuteVerifier( + BiFunction<ActionType<Response>, Request, Response> verifier + ) { + executeVerifier.set(verifier); + } + + @Override + public <Request extends ActionRequest, Response extends ActionResponse> void doExecute( + ActionType<Response> action, + Request request, + ActionListener<Response> listener + ) { + try { + listener.onResponse((Response) executeVerifier.get().apply(action, request)); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Sets the function that will be called when {@link #executeLocally(ActionType, ActionRequest, TaskListener)} is called. The given + * function should return either a subclass of {@link ActionResponse} or {@code null}. + * @param verifier A function which is called in place of {@link #executeLocally(ActionType, ActionRequest, TaskListener)} + */ + public <Request extends ActionRequest, Response extends ActionResponse> void setExecuteLocallyVerifier( + BiFunction<ActionType<Response>, Request, Response> verifier + ) { + executeLocallyVerifier.set(verifier); + } + + @Override + public <Request extends ActionRequest, Response extends ActionResponse> Task executeLocally( + ActionType<Response> action, + Request request, + ActionListener<Response> listener + ) { + listener.onResponse((Response) executeLocallyVerifier.get().apply(action, request)); + return null; + } + + @Override + public <Request extends ActionRequest, Response extends ActionResponse> Task executeLocally( + ActionType<Response> action, + Request request, + TaskListener<Response> listener + ) { + listener.onResponse(null, (Response) executeLocallyVerifier.get().apply(action, request)); + return null; + } + + } +} + diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelTestHelper.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelTestHelper.java new file mode 100644 index 000000000..054710a32 --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/ThreatIntelTestHelper.java @@ -0,0 +1,130 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.securityanalytics.threatIntel; + +import static org.apache.lucene.tests.util.LuceneTestCase.random; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.test.OpenSearchTestCase.randomBoolean; +import static org.opensearch.test.OpenSearchTestCase.randomIntBetween; +import static org.opensearch.test.OpenSearchTestCase.randomNonNegativeLong; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.stream.IntStream; + + +import org.opensearch.OpenSearchException; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.replication.ReplicationResponse; +import org.opensearch.common.Randomness; +import org.opensearch.common.UUIDs; +import org.opensearch.common.collect.Tuple; +import org.opensearch.core.index.shard.ShardId; + +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.RandomObjects; + +public class ThreatIntelTestHelper { + + public static final int MAX_SEQ_NO = 10000; + public static final int MAX_PRIMARY_TERM = 10000; + public static final int MAX_VERSION = 10000; + public static final int MAX_SHARD_ID = 100; + + public static final int RANDOM_STRING_MIN_LENGTH = 2; + public static final int RANDOM_STRING_MAX_LENGTH = 16; + + private static String randomString() { + return OpenSearchTestCase.randomAlphaOfLengthBetween(RANDOM_STRING_MIN_LENGTH, RANDOM_STRING_MAX_LENGTH); + } + + public static String randomLowerCaseString() { + return randomString().toLowerCase(Locale.ROOT); + } + + public static List<String> randomLowerCaseStringList() { + List<String> stringList = new ArrayList<>(); + stringList.add(randomLowerCaseString()); + return stringList; + } + + /** + * Returns random {@link IndexResponse} by generating inputs using random functions. + * It is not guaranteed to generate every possible value, and it is not required since + * it is used by the unit test and will not be validated by the cluster. + */ + private static IndexResponse randomIndexResponse() { + String index = randomLowerCaseString(); + String indexUUid = UUIDs.randomBase64UUID(); + int shardId = randomIntBetween(0, MAX_SHARD_ID); + String id = UUIDs.randomBase64UUID(); + long seqNo = randomIntBetween(0, MAX_SEQ_NO); + long primaryTerm = randomIntBetween(0, MAX_PRIMARY_TERM); + long version = randomIntBetween(0, MAX_VERSION); + boolean created = randomBoolean(); + boolean forcedRefresh = randomBoolean(); + Tuple<ReplicationResponse.ShardInfo, ReplicationResponse.ShardInfo> shardInfo = RandomObjects.randomShardInfo(random()); + IndexResponse actual = new IndexResponse(new ShardId(index, indexUUid, shardId), id, seqNo, primaryTerm, version, created); + actual.setForcedRefresh(forcedRefresh); + actual.setShardInfo(shardInfo.v1()); + + return actual; + } + + // Generate Random Bulk Response with noOfSuccessItems as BulkItemResponse, and include BulkItemResponse.Failure with + // random error message, if hasFailures is true. 
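+ // Shard ids of the generated success items are simply 0..noOfSuccessItems-1; callers only rely on the item count and failure flags.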
+ public static BulkResponse generateRandomBulkResponse(int noOfSuccessItems, boolean hasFailures) { + long took = randomNonNegativeLong(); + long ingestTook = randomNonNegativeLong(); + if (noOfSuccessItems < 1) { + return new BulkResponse(null, took, ingestTook); + } + List<BulkItemResponse> items = new ArrayList<>(); + IntStream.range(0, noOfSuccessItems) + .forEach(shardId -> items.add(new BulkItemResponse(shardId, DocWriteRequest.OpType.CREATE, randomIndexResponse()))); + if (hasFailures) { + final BulkItemResponse.Failure failedToIndex = new BulkItemResponse.Failure( + randomLowerCaseString(), + randomLowerCaseString(), + new OpenSearchException(randomLowerCaseString()) + ); + items.add(new BulkItemResponse(randomIntBetween(0, MAX_SHARD_ID), DocWriteRequest.OpType.CREATE, failedToIndex)); + } + return new BulkResponse(items.toArray(BulkItemResponse[]::new), took, ingestTook); + } + + public static StringBuilder buildFieldNameValuePair(Object field, Object value) { + StringBuilder builder = new StringBuilder(); + builder.append("\"").append(field).append("\":"); + if (!(value instanceof String)) { + return builder.append(value); + } + return builder.append("\"").append(value).append("\""); + } + + public static String removeStartAndEndObject(String content) { + assertNotNull(content); + assertTrue("content length should be at least 2", content.length() > 1); + return content.substring(1, content.length() - 1); + } + + public static double[] toDoubleArray(float[] input) { + return IntStream.range(0, input.length).mapToDouble(i -> input[i]).toArray(); + } + +} + diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/common/DatasourceManifestTests.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/common/DatasourceManifestTests.java new file mode 100644 index 000000000..d98eccbf3 --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/common/DatasourceManifestTests.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.threatIntel.common; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.net.URLConnection; + +import org.opensearch.common.SuppressForbidden; +import org.opensearch.securityanalytics.SecurityAnalyticsRestTestCase; + +@SuppressForbidden(reason = "unit test") +public class DatasourceManifestTests extends SecurityAnalyticsRestTestCase { + + public void testInternalBuild_whenCalled_thenCorrectUserAgentValueIsSet() throws IOException { + URLConnection connection = mock(URLConnection.class); + File manifestFile = new File(this.getClass().getClassLoader().getResource("threatIntel/manifest.json").getFile()); + when(connection.getInputStream()).thenReturn(new FileInputStream(manifestFile)); + + // Run + DatasourceManifest manifest = DatasourceManifest.Builder.internalBuild(connection); + + // Verify + verify(connection).addRequestProperty(Constants.USER_AGENT_KEY, Constants.USER_AGENT_VALUE); + assertEquals("https://test.com/db.zip", manifest.getUrl()); + } +} + diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/common/ThreatIntelLockServiceTests.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/common/ThreatIntelLockServiceTests.java new file mode 100644 index 000000000..c20ec6a5f --- /dev/null +++ 
b/src/test/java/org/opensearch/securityanalytics/threatIntel/common/ThreatIntelLockServiceTests.java @@ -0,0 +1,117 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.common; + +import static org.mockito.Mockito.mock; +import static org.opensearch.securityanalytics.threatIntel.common.ThreatIntelLockService.LOCK_DURATION_IN_SECONDS; +import static org.opensearch.securityanalytics.threatIntel.common.ThreatIntelLockService.RENEW_AFTER_IN_SECONDS; + +import java.time.Instant; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.Before; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestCase; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestHelper; + +public class ThreatIntelLockServiceTests extends ThreatIntelTestCase { + private ThreatIntelLockService threatIntelLockService; + private ThreatIntelLockService noOpsLockService; + + @Before + public void init() { + threatIntelLockService = new ThreatIntelLockService(clusterService, verifyingClient); + noOpsLockService = new ThreatIntelLockService(clusterService, client); + } + + public void testAcquireLock_whenValidInput_thenSucceed() { + // Cannot verify the interaction because LockService is a final class; + // simply call the method to increase coverage + noOpsLockService.acquireLock(ThreatIntelTestHelper.randomLowerCaseString(), randomPositiveLong(), mock(ActionListener.class)); + } + + public void testAcquireLock_whenCalled_thenNotBlocked() { + long expectedDurationInMillis = 1000; + Instant before = Instant.now(); + assertTrue(threatIntelLockService.acquireLock(null, null).isEmpty()); + Instant after = Instant.now(); + assertTrue(after.toEpochMilli() - before.toEpochMilli() < expectedDurationInMillis); + } + + public void testReleaseLock_whenValidInput_thenSucceed() { + // Cannot verify the interaction because LockService is a final class; + // simply call the method to increase coverage + noOpsLockService.releaseLock(null); + } + + public void testRenewLock_whenCalled_thenNotBlocked() { + long expectedDurationInMillis = 1000; + Instant before = Instant.now(); + assertNull(threatIntelLockService.renewLock(null)); + Instant after = Instant.now(); + assertTrue(after.toEpochMilli() - before.toEpochMilli() < expectedDurationInMillis); + } + + public void testGetRenewLockRunnable_whenLockIsFresh_thenDoNotRenew() { + LockModel lockModel = new LockModel( + ThreatIntelTestHelper.randomLowerCaseString(), + ThreatIntelTestHelper.randomLowerCaseString(), + Instant.now(), + LOCK_DURATION_IN_SECONDS, + false + ); + + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + // Verify + assertTrue(actionRequest instanceof UpdateRequest); + return new UpdateResponse( + mock(ShardId.class), + ThreatIntelTestHelper.randomLowerCaseString(), + randomPositiveLong(), + randomPositiveLong(), + randomPositiveLong(), + DocWriteResponse.Result.UPDATED + ); + }); + + AtomicReference<LockModel> reference = new AtomicReference<>(lockModel); + threatIntelLockService.getRenewLockRunnable(reference).run(); + assertEquals(lockModel, reference.get()); + } + + public void testGetRenewLockRunnable_whenLockIsStale_thenRenew() { + LockModel lockModel = new LockModel( + ThreatIntelTestHelper.randomLowerCaseString(), + ThreatIntelTestHelper.randomLowerCaseString(), + Instant.now().minusSeconds(RENEW_AFTER_IN_SECONDS), + LOCK_DURATION_IN_SECONDS, + false + ); + + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + // Verify + assertTrue(actionRequest instanceof UpdateRequest); + return new UpdateResponse( + mock(ShardId.class), + ThreatIntelTestHelper.randomLowerCaseString(), + randomPositiveLong(), + randomPositiveLong(), + randomPositiveLong(), + DocWriteResponse.Result.UPDATED + ); + }); + + AtomicReference<LockModel> reference = new AtomicReference<>(lockModel); + threatIntelLockService.getRenewLockRunnable(reference).run(); + assertNotEquals(lockModel, reference.get()); + } +} + diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/dao/DatasourceDaoTests.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/dao/DatasourceDaoTests.java new file mode 100644 index 000000000..afbb203ec --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/dao/DatasourceDaoTests.java @@ -0,0 +1,388 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.dao; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.List; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.StepListener; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.get.MultiGetItemResponse; +import org.opensearch.action.get.MultiGetRequest; +import org.opensearch.action.get.MultiGetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.cluster.routing.Preference; +import org.opensearch.common.Randomness; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestCase; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestHelper; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.Datasource; +import org.opensearch.securityanalytics.threatIntel.jobscheduler.DatasourceExtension; + +public class DatasourceDaoTests extends ThreatIntelTestCase { + private DatasourceDao datasourceDao; +
+ @Before + public void init() { + datasourceDao = new DatasourceDao(verifyingClient, clusterService); + } + + public void testCreateIndexIfNotExists_whenIndexExist_thenCreateRequestIsNotCalled() { + when(metadata.hasIndex(DatasourceExtension.JOB_INDEX_NAME)).thenReturn(true); + + // Verify + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { throw new RuntimeException("Shouldn't get called"); }); + + // Run + StepListener<Void> stepListener = new StepListener<>(); + datasourceDao.createIndexIfNotExists(stepListener); + + // Verify stepListener is called + stepListener.result(); + } + + public void testCreateIndexIfNotExists_whenIndexNotExist_thenCreateRequestIsCalled() { + when(metadata.hasIndex(DatasourceExtension.JOB_INDEX_NAME)).thenReturn(false); + + // Verify + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + assertTrue(actionRequest instanceof CreateIndexRequest); + CreateIndexRequest request = (CreateIndexRequest) actionRequest; + assertEquals(DatasourceExtension.JOB_INDEX_NAME, request.index()); + assertEquals("1", request.settings().get("index.number_of_shards")); + assertEquals("0-all", request.settings().get("index.auto_expand_replicas")); + assertEquals("true", request.settings().get("index.hidden")); + assertNotNull(request.mappings()); + return null; + }); + + // Run + StepListener<Void> stepListener = new StepListener<>(); + datasourceDao.createIndexIfNotExists(stepListener); + + // Verify stepListener is called + stepListener.result(); + } + + public void testCreateIndexIfNotExists_whenIndexCreatedAlready_thenExceptionIsIgnored() { + when(metadata.hasIndex(DatasourceExtension.JOB_INDEX_NAME)).thenReturn(false); + verifyingClient.setExecuteVerifier( + (actionResponse, actionRequest) -> { throw new ResourceAlreadyExistsException(DatasourceExtension.JOB_INDEX_NAME); } + ); + + // Run + StepListener<Void> stepListener = new StepListener<>(); + datasourceDao.createIndexIfNotExists(stepListener); + + // Verify stepListener is called + stepListener.result(); + } + + public void testCreateIndexIfNotExists_whenExceptionIsThrown_thenExceptionIsThrown() { + when(metadata.hasIndex(DatasourceExtension.JOB_INDEX_NAME)).thenReturn(false); + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { throw new RuntimeException(); }); + + // Run + StepListener<Void> stepListener = new StepListener<>(); + datasourceDao.createIndexIfNotExists(stepListener); + + // Verify the exception is propagated to the listener + expectThrows(RuntimeException.class, () -> stepListener.result()); + } + + public void testUpdateDatasource_whenValidInput_thenSucceed() throws Exception { + String datasourceName = ThreatIntelTestHelper.randomLowerCaseString(); + Datasource datasource = new Datasource( + datasourceName, + new IntervalSchedule(Instant.now().truncatedTo(ChronoUnit.MILLIS), 1, ChronoUnit.DAYS) + ); + Instant previousTime = Instant.now().minusMillis(1); + datasource.setLastUpdateTime(previousTime); + + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + assertTrue(actionRequest instanceof IndexRequest); + IndexRequest request = (IndexRequest) actionRequest; + assertEquals(datasource.getName(), request.id()); + assertEquals(DocWriteRequest.OpType.INDEX, request.opType()); + assertEquals(DatasourceExtension.JOB_INDEX_NAME, request.index()); + assertEquals(WriteRequest.RefreshPolicy.IMMEDIATE, request.getRefreshPolicy()); + return null; + }); + + datasourceDao.updateDatasource(datasource); + assertTrue(previousTime.isBefore(datasource.getLastUpdateTime())); + } + + public void testPutDatasource_whenValidInput_thenSucceed() { + Datasource datasource = randomDatasource(); + Instant previousTime = Instant.now().minusMillis(1); + datasource.setLastUpdateTime(previousTime); + + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + assertTrue(actionRequest instanceof IndexRequest); + IndexRequest indexRequest = (IndexRequest) actionRequest; + assertEquals(DatasourceExtension.JOB_INDEX_NAME, indexRequest.index()); + assertEquals(datasource.getName(), indexRequest.id()); + assertEquals(WriteRequest.RefreshPolicy.IMMEDIATE, indexRequest.getRefreshPolicy()); + assertEquals(DocWriteRequest.OpType.CREATE, indexRequest.opType()); + return null; + }); + + datasourceDao.putDatasource(datasource, mock(ActionListener.class)); + assertTrue(previousTime.isBefore(datasource.getLastUpdateTime())); + } + + public void testGetDatasource_whenException_thenNull() throws Exception { + Datasource datasource = setupClientForGetRequest(true, new IndexNotFoundException(DatasourceExtension.JOB_INDEX_NAME)); + assertNull(datasourceDao.getDatasource(datasource.getName())); + } + + public void testGetDatasource_whenExist_thenReturnDatasource() throws Exception { + Datasource datasource = setupClientForGetRequest(true, null); + assertEquals(datasource, datasourceDao.getDatasource(datasource.getName())); + } + + public void testGetDatasource_whenNotExist_thenNull() throws Exception { + Datasource datasource = setupClientForGetRequest(false, null); + assertNull(datasourceDao.getDatasource(datasource.getName())); + } + + public void testGetDatasource_whenExistWithListener_thenListenerIsCalledWithDatasource() { + Datasource datasource = setupClientForGetRequest(true, null); + ActionListener<Datasource> listener = mock(ActionListener.class); + datasourceDao.getDatasource(datasource.getName(), listener); + verify(listener).onResponse(eq(datasource)); + } + + public void testGetDatasource_whenNotExistWithListener_thenListenerIsCalledWithNull() { + Datasource datasource = setupClientForGetRequest(false, null); + ActionListener<Datasource> listener = mock(ActionListener.class); + datasourceDao.getDatasource(datasource.getName(), listener); + verify(listener).onResponse(null); + } + + private Datasource setupClientForGetRequest(final boolean isExist, final RuntimeException exception) { + Datasource datasource = randomDatasource(); + + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + assertTrue(actionRequest instanceof GetRequest); + GetRequest request = (GetRequest) actionRequest; + assertEquals(datasource.getName(), request.id()); + assertEquals(DatasourceExtension.JOB_INDEX_NAME, request.index()); + GetResponse response = getMockedGetResponse(isExist ? datasource : null); + if (exception != null) { + throw exception; + } + return response; + }); + return datasource; + } + + public void testDeleteDatasource_whenValidInput_thenSucceed() { + Datasource datasource = randomDatasource(); + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + // Verify + assertTrue(actionRequest instanceof DeleteRequest); + DeleteRequest request = (DeleteRequest) actionRequest; + assertEquals(DatasourceExtension.JOB_INDEX_NAME, request.index()); + assertEquals(DocWriteRequest.OpType.DELETE, request.opType()); + assertEquals(datasource.getName(), request.id()); + assertEquals(WriteRequest.RefreshPolicy.IMMEDIATE, request.getRefreshPolicy()); + + DeleteResponse response = mock(DeleteResponse.class); + when(response.status()).thenReturn(RestStatus.OK); + return response; + }); + + // Run + datasourceDao.deleteDatasource(datasource); + } + + public void testDeleteDatasource_whenIndexNotFound_thenThrowException() { + Datasource datasource = randomDatasource(); + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + DeleteResponse response = mock(DeleteResponse.class); + when(response.status()).thenReturn(RestStatus.NOT_FOUND); + return response; + }); + + // Run + expectThrows(ResourceNotFoundException.class, () -> datasourceDao.deleteDatasource(datasource)); + } + + public void testGetDatasources_whenValidInput_thenSucceed() { + List<Datasource> datasources = Arrays.asList(randomDatasource(), randomDatasource()); + String[] names = datasources.stream().map(Datasource::getName).toArray(String[]::new); + ActionListener<List<Datasource>> listener = mock(ActionListener.class); + MultiGetItemResponse[] multiGetItemResponses = datasources.stream().map(datasource -> { + GetResponse getResponse = getMockedGetResponse(datasource); + MultiGetItemResponse multiGetItemResponse = mock(MultiGetItemResponse.class); + when(multiGetItemResponse.getResponse()).thenReturn(getResponse); + return multiGetItemResponse; + }).toArray(MultiGetItemResponse[]::new); + + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + // Verify + assertTrue(actionRequest instanceof MultiGetRequest); + MultiGetRequest request = (MultiGetRequest) actionRequest; + assertEquals(2, request.getItems().size()); + for (MultiGetRequest.Item item : request.getItems()) { + assertEquals(DatasourceExtension.JOB_INDEX_NAME, item.index()); + assertTrue(datasources.stream().filter(datasource -> datasource.getName().equals(item.id())).findAny().isPresent()); + } + + MultiGetResponse response = mock(MultiGetResponse.class); + when(response.getResponses()).thenReturn(multiGetItemResponses); + return response; + }); + + // Run + datasourceDao.getDatasources(names, listener); + + // Verify + ArgumentCaptor<List<Datasource>> captor = ArgumentCaptor.forClass(List.class); + verify(listener).onResponse(captor.capture()); + assertEquals(datasources, captor.getValue()); + } + + public void testGetAllDatasources_whenAsynchronous_thenSucceed() { + List<Datasource> datasources = Arrays.asList(randomDatasource(), randomDatasource()); + ActionListener<List<Datasource>> listener = mock(ActionListener.class); + SearchHits searchHits = getMockedSearchHits(datasources); + + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + // Verify + assertTrue(actionRequest instanceof SearchRequest); + SearchRequest request = (SearchRequest) actionRequest; + assertEquals(1, request.indices().length); + assertEquals(DatasourceExtension.JOB_INDEX_NAME, request.indices()[0]); + assertEquals(QueryBuilders.matchAllQuery(), request.source().query()); + assertEquals(1000, request.source().size()); + assertEquals(Preference.PRIMARY.type(), request.preference()); + + SearchResponse response = mock(SearchResponse.class); + when(response.getHits()).thenReturn(searchHits); + return response; + }); + + // Run + datasourceDao.getAllDatasources(listener); + + // Verify + ArgumentCaptor<List<Datasource>> captor = ArgumentCaptor.forClass(List.class); + verify(listener).onResponse(captor.capture()); + assertEquals(datasources, captor.getValue()); + } + + public void testGetAllDatasources_whenSynchronous_thenSucceed() { + List<Datasource> datasources = Arrays.asList(randomDatasource(), randomDatasource()); + SearchHits searchHits = getMockedSearchHits(datasources); + + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + // Verify + assertTrue(actionRequest instanceof SearchRequest); + SearchRequest request = (SearchRequest) actionRequest; + assertEquals(1, request.indices().length); + assertEquals(DatasourceExtension.JOB_INDEX_NAME, request.indices()[0]); + assertEquals(QueryBuilders.matchAllQuery(), request.source().query()); + assertEquals(1000, request.source().size()); + assertEquals(Preference.PRIMARY.type(), request.preference()); + + SearchResponse response = mock(SearchResponse.class); + when(response.getHits()).thenReturn(searchHits); + return response; + }); + + // Run + List<Datasource> result = datasourceDao.getAllDatasources(); + + // Verify + assertEquals(datasources, result); + } + + public void testUpdateDatasource_whenValidInput_thenUpdate() { + List<Datasource> datasources = Arrays.asList(randomDatasource(), randomDatasource()); + + verifyingClient.setExecuteVerifier((actionResponse, actionRequest) -> { + // Verify + assertTrue(actionRequest instanceof BulkRequest); + BulkRequest bulkRequest = (BulkRequest) actionRequest; + assertEquals(2, bulkRequest.requests().size()); + for (int i = 0; i < bulkRequest.requests().size(); i++) { + IndexRequest request = (IndexRequest) bulkRequest.requests().get(i); + assertEquals(DatasourceExtension.JOB_INDEX_NAME, request.index()); + assertEquals(datasources.get(i).getName(), request.id()); + assertEquals(DocWriteRequest.OpType.INDEX, request.opType()); + } + return null; + }); + + datasourceDao.updateDatasource(datasources, mock(ActionListener.class)); + } + + private SearchHits getMockedSearchHits(List<Datasource> datasources) { + SearchHit[] searchHitArray = datasources.stream().map(this::toBytesReference).map(this::toSearchHit).toArray(SearchHit[]::new); + + return new SearchHits(searchHitArray, new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1); + } + + private GetResponse getMockedGetResponse(Datasource datasource) { + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(datasource != null); + when(response.getSourceAsBytesRef()).thenReturn(toBytesReference(datasource)); + return response; + } + + private BytesReference toBytesReference(Datasource datasource) { + if (datasource == null) { + return null; + } + + try { + return BytesReference.bytes(datasource.toXContent(JsonXContent.contentBuilder(), null)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private SearchHit toSearchHit(BytesReference bytesReference) { + SearchHit searchHit = new SearchHit(Randomness.get().nextInt()); + searchHit.sourceRef(bytesReference); + return searchHit; + } +}
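+ +// Sketch of caller-side usage for the async API exercised above (hypothetical names, not part of this PR): +// datasourceDao.getDatasource("my-feed", ActionListener.wrap(found -> process(found), e -> logger.error("lookup failed", e)));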
diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/DatasourceExtensionTests.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/DatasourceExtensionTests.java new file mode 100644 index 000000000..11b3edf9d --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/DatasourceExtensionTests.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.jobscheduler; + +import static org.opensearch.securityanalytics.threatIntel.jobscheduler.DatasourceExtension.JOB_INDEX_NAME; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; + +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.jobscheduler.spi.JobDocVersion; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestCase; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestHelper; + +public class DatasourceExtensionTests extends ThreatIntelTestCase { + public void testBasic() { + DatasourceExtension extension = new DatasourceExtension(); + assertEquals("scheduler_security_analytics_threatintel_datasource", extension.getJobType()); + assertEquals(JOB_INDEX_NAME, extension.getJobIndex()); + assertEquals(DatasourceRunner.getJobRunnerInstance(), extension.getJobRunner()); + } + + public void testParser() throws Exception { + DatasourceExtension extension = new DatasourceExtension(); + String id = ThreatIntelTestHelper.randomLowerCaseString(); + IntervalSchedule schedule = new IntervalSchedule(Instant.now().truncatedTo(ChronoUnit.MILLIS), 1, ChronoUnit.DAYS); + Datasource datasource = new Datasource(id, schedule); + + Datasource anotherDatasource = (Datasource) extension.getJobParser() + .parse( + createParser(datasource.toXContent(XContentFactory.jsonBuilder(), null)), + ThreatIntelTestHelper.randomLowerCaseString(), + new JobDocVersion(randomPositiveLong(), randomPositiveLong(), randomPositiveLong()) + ); + + assertEquals(datasource, anotherDatasource); + } +} diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/DatasourceRunnerTests.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/DatasourceRunnerTests.java new file mode 100644 index 000000000..ea51dd5b0 --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/DatasourceRunnerTests.java @@ -0,0 +1,177 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.jobscheduler; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.internal.verification.VerificationModeFactory.times; + +import java.io.IOException; +import java.time.Instant; +import java.util.Optional; + +import org.junit.Before; + +import org.opensearch.jobscheduler.spi.JobDocVersion; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.LockModel; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestCase; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestHelper; +import org.opensearch.securityanalytics.threatIntel.common.DatasourceState; +import org.opensearch.securityanalytics.threatIntel.common.ThreatIntelLockService; + +public class DatasourceRunnerTests extends ThreatIntelTestCase { + @Before + public void init() { + DatasourceRunner.getJobRunnerInstance() + .initialize(clusterService, datasourceUpdateService, datasourceDao, threatIntelExecutor, threatIntelLockService); + } + + public void testGetJobRunnerInstance_whenCalledAgain_thenReturnSameInstance() { + assertSame(DatasourceRunner.getJobRunnerInstance(), DatasourceRunner.getJobRunnerInstance()); + } + + public void testRunJob_whenInvalidClass_thenThrowException() { + JobDocVersion jobDocVersion = new JobDocVersion(randomInt(), randomInt(), randomInt()); + String jobIndexName = ThreatIntelTestHelper.randomLowerCaseString(); + String jobId = ThreatIntelTestHelper.randomLowerCaseString(); + JobExecutionContext jobExecutionContext = new JobExecutionContext(Instant.now(), jobDocVersion, lockService, jobIndexName, jobId); + ScheduledJobParameter jobParameter = mock(ScheduledJobParameter.class); + + // Run + expectThrows(IllegalStateException.class, () -> DatasourceRunner.getJobRunnerInstance().runJob(jobParameter, jobExecutionContext)); + } + + public void testRunJob_whenValidInput_thenSucceed() throws IOException { + JobDocVersion jobDocVersion = new JobDocVersion(randomInt(), randomInt(), randomInt()); + String jobIndexName = ThreatIntelTestHelper.randomLowerCaseString(); + String jobId = ThreatIntelTestHelper.randomLowerCaseString(); + JobExecutionContext jobExecutionContext = new JobExecutionContext(Instant.now(), jobDocVersion, lockService, jobIndexName, jobId); + Datasource datasource = randomDatasource(); + + LockModel lockModel = randomLockModel(); + when(threatIntelLockService.acquireLock(datasource.getName(), ThreatIntelLockService.LOCK_DURATION_IN_SECONDS)).thenReturn( + Optional.of(lockModel) + ); + + // Run + DatasourceRunner.getJobRunnerInstance().runJob(datasource, jobExecutionContext); + + // Verify + verify(threatIntelLockService).acquireLock(datasource.getName(), ThreatIntelLockService.LOCK_DURATION_IN_SECONDS); + verify(datasourceDao).getDatasource(datasource.getName()); + verify(threatIntelLockService).releaseLock(lockModel); + } + + public void testUpdateDatasourceRunner_whenExceptionBeforeAcquiringLock_thenNoReleaseLock() { + ScheduledJobParameter jobParameter = mock(ScheduledJobParameter.class); + when(jobParameter.getName()).thenReturn(ThreatIntelTestHelper.randomLowerCaseString()); + when(threatIntelLockService.acquireLock(jobParameter.getName(), ThreatIntelLockService.LOCK_DURATION_IN_SECONDS)).thenThrow( + new RuntimeException() + ); + + // Run + expectThrows(Exception.class, () -> DatasourceRunner.getJobRunnerInstance().updateDatasourceRunner(jobParameter).run()); + + // Verify + verify(threatIntelLockService, never()).releaseLock(any()); + } + + public void testUpdateDatasourceRunner_whenExceptionAfterAcquiringLock_thenReleaseLock() throws IOException { + ScheduledJobParameter jobParameter = mock(ScheduledJobParameter.class); + when(jobParameter.getName()).thenReturn(ThreatIntelTestHelper.randomLowerCaseString()); + LockModel lockModel = randomLockModel(); + when(threatIntelLockService.acquireLock(jobParameter.getName(), ThreatIntelLockService.LOCK_DURATION_IN_SECONDS)).thenReturn( + Optional.of(lockModel) + ); + when(datasourceDao.getDatasource(jobParameter.getName())).thenThrow(new
RuntimeException()); + + // Run + DatasourceRunner.getJobRunnerInstance().updateDatasourceRunner(jobParameter).run(); + + // Verify + verify(threatIntelLockService).releaseLock(any()); + } + + public void testUpdateDatasource_whenDatasourceDoesNotExist_thenDoNothing() throws IOException { + Datasource datasource = new Datasource(); + + // Run + DatasourceRunner.getJobRunnerInstance().updateDatasource(datasource, mock(Runnable.class)); + + // Verify + verify(datasourceUpdateService, never()).deleteUnusedIndices(any()); + } + + public void testUpdateDatasource_whenInvalidState_thenUpdateLastFailedAt() throws IOException { + Datasource datasource = new Datasource(); + datasource.enable(); + datasource.getUpdateStats().setLastFailedAt(null); + datasource.setState(randomStateExcept(DatasourceState.AVAILABLE)); + when(datasourceDao.getDatasource(datasource.getName())).thenReturn(datasource); + + // Run + DatasourceRunner.getJobRunnerInstance().updateDatasource(datasource, mock(Runnable.class)); + + // Verify + assertFalse(datasource.isEnabled()); + assertNotNull(datasource.getUpdateStats().getLastFailedAt()); + verify(datasourceDao).updateDatasource(datasource); + } + + public void testUpdateDatasource_whenValidInput_thenSucceed() throws IOException { + Datasource datasource = randomDatasource(); + datasource.setState(DatasourceState.AVAILABLE); + when(datasourceDao.getDatasource(datasource.getName())).thenReturn(datasource); + Runnable renewLock = mock(Runnable.class); + + // Run + DatasourceRunner.getJobRunnerInstance().updateDatasource(datasource, renewLock); + + // Verify + verify(datasourceUpdateService, times(2)).deleteUnusedIndices(datasource); + verify(datasourceUpdateService).updateOrCreateThreatIntelFeedData(datasource, renewLock); + verify(datasourceUpdateService).updateDatasource(datasource, datasource.getSchedule(), DatasourceTask.ALL); + } + + public void testUpdateDatasource_whenDeleteTask_thenDeleteOnly() throws IOException { + Datasource datasource = randomDatasource(); + datasource.setState(DatasourceState.AVAILABLE); + datasource.setTask(DatasourceTask.DELETE_UNUSED_INDICES); + when(datasourceDao.getDatasource(datasource.getName())).thenReturn(datasource); + Runnable renewLock = mock(Runnable.class); + + // Run + DatasourceRunner.getJobRunnerInstance().updateDatasource(datasource, renewLock); + + // Verify + verify(datasourceUpdateService, times(2)).deleteUnusedIndices(datasource); + verify(datasourceUpdateService, never()).updateOrCreateThreatIntelFeedData(datasource, renewLock); + verify(datasourceUpdateService).updateDatasource(datasource, datasource.getSchedule(), DatasourceTask.ALL); + } + + public void testUpdateDatasourceExceptionHandling() throws IOException { + Datasource datasource = new Datasource(); + datasource.setName(ThreatIntelTestHelper.randomLowerCaseString()); + datasource.getUpdateStats().setLastFailedAt(null); + when(datasourceDao.getDatasource(datasource.getName())).thenReturn(datasource); + doThrow(new RuntimeException("test failure")).when(datasourceUpdateService).deleteUnusedIndices(any()); + + // Run + DatasourceRunner.getJobRunnerInstance().updateDatasource(datasource, mock(Runnable.class)); + + // Verify + assertNotNull(datasource.getUpdateStats().getLastFailedAt()); + verify(datasourceDao).updateDatasource(datasource); + } +} + diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/DatasourceTests.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/DatasourceTests.java new file mode 100644 
index 000000000..2a6154b87 --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/DatasourceTests.java @@ -0,0 +1,129 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.jobscheduler; + +import static org.opensearch.securityanalytics.threatIntel.jobscheduler.Datasource.THREAT_INTEL_DATA_INDEX_NAME_PREFIX; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; + +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestCase; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestHelper; + +public class DatasourceTests extends ThreatIntelTestCase { + + public void testParser_whenAllValueIsFilled_thenSucceed() throws IOException { + String id = ThreatIntelTestHelper.randomLowerCaseString(); + IntervalSchedule schedule = new IntervalSchedule(Instant.now().truncatedTo(ChronoUnit.MILLIS), 1, ChronoUnit.DAYS); + List<String> stringList = new ArrayList<>(); + stringList.add("ip"); + + Datasource datasource = new Datasource(id, schedule); + datasource.enable(); + datasource.setCurrentIndex(ThreatIntelTestHelper.randomLowerCaseString()); + datasource.getDatabase().setFields(Arrays.asList("field1", "field2")); + datasource.getDatabase().setFeedId("test123"); + datasource.getDatabase().setFeedName("name"); + datasource.getDatabase().setFeedFormat("csv"); + datasource.getDatabase().setEndpoint("url"); + datasource.getDatabase().setDescription("test description"); + datasource.getDatabase().setOrganization("test org"); + datasource.getDatabase().setContained_iocs_field(stringList); + datasource.getDatabase().setIocCol("0"); + + datasource.getUpdateStats().setLastProcessingTimeInMillis(randomPositiveLong()); + datasource.getUpdateStats().setLastSucceededAt(Instant.now().truncatedTo(ChronoUnit.MILLIS)); + datasource.getUpdateStats().setLastSkippedAt(Instant.now().truncatedTo(ChronoUnit.MILLIS)); + datasource.getUpdateStats().setLastFailedAt(Instant.now().truncatedTo(ChronoUnit.MILLIS)); + + Datasource anotherDatasource = Datasource.PARSER.parse( + createParser(datasource.toXContent(XContentFactory.jsonBuilder(), null)), + null + ); + assertEquals(datasource, anotherDatasource); + } + + public void testParser_whenNullForOptionalFields_thenSucceed() throws IOException { + String id = ThreatIntelTestHelper.randomLowerCaseString(); + IntervalSchedule schedule = new IntervalSchedule(Instant.now().truncatedTo(ChronoUnit.MILLIS), 1, ChronoUnit.DAYS); + Datasource datasource = new Datasource(id, schedule); + Datasource anotherDatasource = Datasource.PARSER.parse( + createParser(datasource.toXContent(XContentFactory.jsonBuilder(), null)), + null + ); + assertEquals(datasource, anotherDatasource); + } + + public void testCurrentIndexName_whenNotExpired_thenReturnName() { + List<String> stringList = new ArrayList<>(); + stringList.add("ip"); + + String id = ThreatIntelTestHelper.randomLowerCaseString(); + Datasource datasource = new Datasource(); + datasource.setName(id); +
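+ // The database fields below are the minimum needed for a valid entry; currentIndexName() should then return the + // index set via newIndexName(...) (see testNewIndexName below for the expected <prefix>.<name>.<suffix> format).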
datasource.setCurrentIndex(datasource.newIndexName(ThreatIntelTestHelper.randomLowerCaseString())); + datasource.getDatabase().setFeedId("test123"); + datasource.getDatabase().setFeedName("name"); + datasource.getDatabase().setFeedFormat("csv"); + datasource.getDatabase().setEndpoint("url"); + datasource.getDatabase().setDescription("test description"); + datasource.getDatabase().setOrganization("test org"); + datasource.getDatabase().setContained_iocs_field(stringList); + datasource.getDatabase().setIocCol("0"); + datasource.getDatabase().setFields(new ArrayList<>()); + + assertNotNull(datasource.currentIndexName()); + } + + public void testNewIndexName_whenCalled_thenReturnedExpectedValue() { + String name = ThreatIntelTestHelper.randomLowerCaseString(); + String suffix = ThreatIntelTestHelper.randomLowerCaseString(); + Datasource datasource = new Datasource(); + datasource.setName(name); + assertEquals(String.format(Locale.ROOT, "%s.%s.%s", THREAT_INTEL_DATA_INDEX_NAME_PREFIX, name, suffix), datasource.newIndexName(suffix)); + } + + public void testResetDatabase_whenCalled_thenNullifySomeFields() { + Datasource datasource = randomDatasource(); + assertNotNull(datasource.getDatabase().getFeedId()); + assertNotNull(datasource.getDatabase().getFeedName()); + assertNotNull(datasource.getDatabase().getFeedFormat()); + assertNotNull(datasource.getDatabase().getEndpoint()); + assertNotNull(datasource.getDatabase().getDescription()); + assertNotNull(datasource.getDatabase().getOrganization()); + assertNotNull(datasource.getDatabase().getContained_iocs_field()); + assertNotNull(datasource.getDatabase().getIocCol()); + + // Run + datasource.resetDatabase(); + + // Verify + assertNull(datasource.getDatabase().getFeedId()); + assertNull(datasource.getDatabase().getFeedName()); + assertNull(datasource.getDatabase().getFeedFormat()); + assertNull(datasource.getDatabase().getEndpoint()); + assertNull(datasource.getDatabase().getDescription()); + assertNull(datasource.getDatabase().getOrganization()); + assertNull(datasource.getDatabase().getContained_iocs_field()); + assertNull(datasource.getDatabase().getIocCol()); + } + + public void testLockDurationSeconds() { + Datasource datasource = new Datasource(); + assertNotNull(datasource.getLockDurationSeconds()); + } +} + diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/DatasourceUpdateServiceTests.java b/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/DatasourceUpdateServiceTests.java new file mode 100644 index 000000000..f16f37035 --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/jobscheduler/DatasourceUpdateServiceTests.java @@ -0,0 +1,275 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.securityanalytics.threatIntel.jobscheduler; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +import 
org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.CSVParser; +import org.junit.Before; +import org.opensearch.OpenSearchException; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelFeedParser; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestCase; +import org.opensearch.securityanalytics.threatIntel.ThreatIntelTestHelper; +import org.opensearch.securityanalytics.threatIntel.common.DatasourceManifest; +import org.opensearch.securityanalytics.threatIntel.common.DatasourceState; + +@SuppressForbidden(reason = "unit test") +public class DatasourceUpdateServiceTests extends ThreatIntelTestCase { + private DatasourceUpdateService datasourceUpdateService; + + @Before + public void init() { + datasourceUpdateService = new DatasourceUpdateService(clusterService, datasourceDao, threatIntelFeedDataService); + } + + public void testUpdateOrCreateThreatIntelFeedData_whenHashValueIsSame_thenSkipUpdate() throws IOException { + File manifestFile = new File(this.getClass().getClassLoader().getResource("threatIntel/manifest.json").getFile()); + DatasourceManifest manifest = DatasourceManifest.Builder.build(manifestFile.toURI().toURL()); + + Datasource datasource = new Datasource(); + datasource.setState(DatasourceState.AVAILABLE); + datasource.getDatabase().setFeedId(manifest.getFeedId()); + datasource.getDatabase().setFeedName(manifest.getName()); + datasource.getDatabase().setFeedFormat(manifest.getFeedType()); + datasource.getDatabase().setEndpoint(manifest.getUrl()); + datasource.getDatabase().setOrganization(manifest.getOrganization()); + datasource.getDatabase().setDescription(manifest.getDescription()); + datasource.getDatabase().setContained_iocs_field(manifest.getContainedIocs()); + datasource.getDatabase().setIocCol(manifest.getIocCol()); + + datasource.getDatabase().setFields(Arrays.asList("ip", "region")); + + // Run + datasourceUpdateService.updateOrCreateThreatIntelFeedData(datasource, mock(Runnable.class)); + + // Verify + assertNotNull(datasource.getUpdateStats().getLastSkippedAt()); + verify(datasourceDao).updateDatasource(datasource); + } + + public void testUpdateOrCreateThreatIntelFeedData_whenInvalidData_thenThrowException() throws IOException { + File manifestFile = new File(this.getClass().getClassLoader().getResource("threatIntel/manifest.json").getFile()); + DatasourceManifest manifest = DatasourceManifest.Builder.build(manifestFile.toURI().toURL()); + + File sampleFile = new File( + this.getClass().getClassLoader().getResource("threatIntel/sample_invalid_less_than_two_fields.csv").getFile() + ); + when(ThreatIntelFeedParser.getThreatIntelFeedReaderCSV(any())).thenReturn(CSVParser.parse(sampleFile, StandardCharsets.UTF_8, CSVFormat.RFC4180)); + + Datasource datasource = new Datasource(); + datasource.setState(DatasourceState.AVAILABLE); + datasource.getDatabase().setFeedId(manifest.getFeedId()); + datasource.getDatabase().setFeedName(manifest.getName()); + datasource.getDatabase().setFeedFormat(manifest.getFeedType()); + datasource.getDatabase().setEndpoint(manifest.getUrl()); + datasource.getDatabase().setOrganization(manifest.getOrganization()); + datasource.getDatabase().setDescription(manifest.getDescription()); + datasource.getDatabase().setContained_iocs_field(manifest.getContainedIocs()); + datasource.getDatabase().setIocCol(manifest.getIocCol()); + + datasource.getDatabase().setFields(Arrays.asList("ip", "region")); + + // Run + expectThrows(OpenSearchException.class, () -> datasourceUpdateService.updateOrCreateThreatIntelFeedData(datasource, mock(Runnable.class))); + } + + public void testUpdateOrCreateThreatIntelFeedData_whenIncompatibleFields_thenThrowException() throws IOException { + File manifestFile = new File(this.getClass().getClassLoader().getResource("threatIntel/manifest.json").getFile()); + DatasourceManifest manifest = DatasourceManifest.Builder.build(manifestFile.toURI().toURL()); + + File sampleFile = new File(this.getClass().getClassLoader().getResource("threatIntel/sample_valid.csv").getFile()); + when(ThreatIntelFeedParser.getThreatIntelFeedReaderCSV(any())).thenReturn(CSVParser.parse(sampleFile, StandardCharsets.UTF_8, CSVFormat.RFC4180)); + + Datasource datasource = new Datasource(); + datasource.setState(DatasourceState.AVAILABLE); + datasource.getDatabase().setFeedId(manifest.getFeedId()); + datasource.getDatabase().setFeedName(manifest.getName()); + datasource.getDatabase().setFeedFormat(manifest.getFeedType()); + datasource.getDatabase().setEndpoint(manifest.getUrl()); + datasource.getDatabase().setOrganization(manifest.getOrganization()); + datasource.getDatabase().setDescription(manifest.getDescription()); + datasource.getDatabase().setContained_iocs_field(manifest.getContainedIocs()); + datasource.getDatabase().setIocCol(manifest.getIocCol()); + + datasource.getDatabase().setFields(Arrays.asList("ip", "region")); + + // Run + expectThrows(OpenSearchException.class, () -> datasourceUpdateService.updateOrCreateThreatIntelFeedData(datasource, mock(Runnable.class))); + } + + public void testUpdateOrCreateThreatIntelFeedData_whenValidInput_thenSucceed() throws IOException { + File manifestFile = new File(this.getClass().getClassLoader().getResource("threatIntel/manifest.json").getFile()); + DatasourceManifest manifest = DatasourceManifest.Builder.build(manifestFile.toURI().toURL()); + + File sampleFile = new File(this.getClass().getClassLoader().getResource("threatIntel/sample_valid.csv").getFile()); + when(ThreatIntelFeedParser.getThreatIntelFeedReaderCSV(any())).thenReturn(CSVParser.parse(sampleFile, StandardCharsets.UTF_8, CSVFormat.RFC4180)); + ShardRouting shardRouting = mock(ShardRouting.class); + when(shardRouting.started()).thenReturn(true); + when(routingTable.allShards(anyString())).thenReturn(Arrays.asList(shardRouting)); + + Datasource datasource = new Datasource(); + datasource.setState(DatasourceState.AVAILABLE); + datasource.getDatabase().setFeedId(manifest.getFeedId()); + datasource.getDatabase().setFeedName(manifest.getName()); + datasource.getDatabase().setFeedFormat(manifest.getFeedType()); + datasource.getDatabase().setEndpoint(manifest.getUrl()); + datasource.getDatabase().setOrganization(manifest.getOrganization()); + datasource.getDatabase().setDescription(manifest.getDescription()); + datasource.getDatabase().setContained_iocs_field(manifest.getContainedIocs()); + datasource.getDatabase().setIocCol(manifest.getIocCol()); + + datasource.getUpdateStats().setLastSucceededAt(null); +
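+ // Both update stats start out null so the assertNotNull checks below prove the update populated them.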
datasource.getUpdateStats().setLastProcessingTimeInMillis(null); + + // Run + datasourceUpdateService.updateOrCreateThreatIntelFeedData(datasource, mock(Runnable.class)); + + // Verify + assertEquals(manifest.getFeedId(), datasource.getDatabase().getFeedId()); + assertEquals(manifest.getName(), datasource.getDatabase().getFeedName()); + assertEquals(manifest.getFeedType(), datasource.getDatabase().getFeedFormat()); + assertEquals(manifest.getUrl(), datasource.getDatabase().getEndpoint()); + assertEquals(manifest.getOrganization(), datasource.getDatabase().getOrganization()); + assertEquals(manifest.getDescription(), datasource.getDatabase().getDescription()); + assertEquals(manifest.getContainedIocs(), datasource.getDatabase().getContained_iocs_field()); + assertEquals(manifest.getIocCol(), datasource.getDatabase().getIocCol()); + + assertNotNull(datasource.getUpdateStats().getLastSucceededAt()); + assertNotNull(datasource.getUpdateStats().getLastProcessingTimeInMillis()); + verify(datasourceDao, times(2)).updateDatasource(datasource); + verify(threatIntelFeedDataService).saveThreatIntelFeedDataCSV(eq(datasource.currentIndexName()), isA(String[].class), any(Iterator.class), any(Runnable.class), eq(manifest)); + } + + public void testWaitUntilAllShardsStarted_whenTimedOut_thenThrowException() { + String indexName = ThreatIntelTestHelper.randomLowerCaseString(); + ShardRouting shardRouting = mock(ShardRouting.class); + when(shardRouting.started()).thenReturn(false); + when(routingTable.allShards(indexName)).thenReturn(Arrays.asList(shardRouting)); + + // Run + Exception e = expectThrows(OpenSearchException.class, () -> datasourceUpdateService.waitUntilAllShardsStarted(indexName, 10)); + + // Verify + assertTrue(e.getMessage().contains("did not complete")); + } + + public void testWaitUntilAllShardsStarted_whenInterrupted_thenThrowException() { + String indexName = ThreatIntelTestHelper.randomLowerCaseString(); + ShardRouting shardRouting = mock(ShardRouting.class); + when(shardRouting.started()).thenReturn(false); + when(routingTable.allShards(indexName)).thenReturn(Arrays.asList(shardRouting)); + + // Run + Thread.currentThread().interrupt(); + Exception e = expectThrows(RuntimeException.class, () -> datasourceUpdateService.waitUntilAllShardsStarted(indexName, 10)); + + // Verify + assertEquals(InterruptedException.class, e.getCause().getClass()); + } + + public void testGetHeaderFields_whenValidInput_thenReturnCorrectValue() throws IOException { + File manifestFile = new File(this.getClass().getClassLoader().getResource("threatIntel/manifest.json").getFile()); + + File sampleFile = new File(this.getClass().getClassLoader().getResource("threatIntel/sample_valid.csv").getFile()); + when(ThreatIntelFeedParser.getThreatIntelFeedReaderCSV(any())).thenReturn(CSVParser.parse(sampleFile, StandardCharsets.UTF_8, CSVFormat.RFC4180)); + + // Run + assertEquals(Arrays.asList("country_name"), datasourceUpdateService.getHeaderFields(manifestFile.toURI().toURL().toExternalForm())); + } + + public void testDeleteUnusedIndices_whenValidInput_thenSucceed() { + String datasourceName = ThreatIntelTestHelper.randomLowerCaseString(); + String indexPrefix = String.format(".threatintel-data.%s.", datasourceName); + Instant now = Instant.now(); + String currentIndex = indexPrefix + now.toEpochMilli(); + String oldIndex = indexPrefix + now.minusMillis(1).toEpochMilli(); + String lingeringIndex = indexPrefix + now.minusMillis(2).toEpochMilli(); + Datasource datasource = new Datasource(); + datasource.setName(datasourceName); + datasource.setCurrentIndex(currentIndex); + datasource.getIndices().add(currentIndex); + datasource.getIndices().add(oldIndex); + datasource.getIndices().add(lingeringIndex); + + when(metadata.hasIndex(currentIndex)).thenReturn(true); + when(metadata.hasIndex(oldIndex)).thenReturn(true); + when(metadata.hasIndex(lingeringIndex)).thenReturn(false); + + datasourceUpdateService.deleteUnusedIndices(datasource); + + assertEquals(1, datasource.getIndices().size()); + assertEquals(currentIndex, datasource.getIndices().get(0)); + verify(datasourceDao).updateDatasource(datasource); + verify(threatIntelFeedDataService).deleteThreatIntelDataIndex(oldIndex); + } + + public void testUpdateDatasource_whenNoChange_thenNoUpdate() { + Datasource datasource = randomDatasource(); + + // Run + datasourceUpdateService.updateDatasource(datasource, datasource.getSchedule(), datasource.getTask()); + + // Verify + verify(datasourceDao, never()).updateDatasource(any()); + } + + public void testUpdateDatasource_whenChange_thenUpdate() { + Datasource datasource = randomDatasource(); + datasource.setTask(DatasourceTask.ALL); + + // Run + datasourceUpdateService.updateDatasource( + datasource, + new IntervalSchedule(Instant.now(), datasource.getSchedule().getInterval() + 1, ChronoUnit.DAYS), + datasource.getTask() + ); + datasourceUpdateService.updateDatasource(datasource, datasource.getSchedule(), DatasourceTask.DELETE_UNUSED_INDICES); + + // Verify + verify(datasourceDao, times(2)).updateDatasource(any()); + } + + public void testGetHeaderFields_whenValidInput_thenSucceed() throws IOException { + File manifestFile = new File(this.getClass().getClassLoader().getResource("threatIntel/manifest.json").getFile()); + File sampleFile = new File(this.getClass().getClassLoader().getResource("threatIntel/sample_valid.csv").getFile()); + when(ThreatIntelFeedParser.getThreatIntelFeedReaderCSV(any())).thenReturn(CSVParser.parse(sampleFile, StandardCharsets.UTF_8, CSVFormat.RFC4180)); + + // Run + List<String> fields = datasourceUpdateService.getHeaderFields(manifestFile.toURI().toURL().toExternalForm()); + + // Verify + List<String> expectedFields = Arrays.asList("country_name"); + assertEquals(expectedFields, fields); + } +} diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/sample_invalid_less_than_two_fields.csv b/src/test/java/org/opensearch/securityanalytics/threatIntel/sample_invalid_less_than_two_fields.csv new file mode 100644 index 000000000..08670061c --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/sample_invalid_less_than_two_fields.csv @@ -0,0 +1,2 @@ +network +1.0.0.0/24 \ No newline at end of file diff --git a/src/test/java/org/opensearch/securityanalytics/threatIntel/sample_valid.csv b/src/test/java/org/opensearch/securityanalytics/threatIntel/sample_valid.csv new file mode 100644 index 000000000..fad1eb6fd --- /dev/null +++ b/src/test/java/org/opensearch/securityanalytics/threatIntel/sample_valid.csv @@ -0,0 +1,3 @@ +ip,region +1.0.0.0/24,Australia +10.0.0.0/24,USA \ No newline at end of file