using proto message envelope for kafka transport

ag060 committed Dec 30, 2024
1 parent afe3a65 commit 5a75af5
Showing 10 changed files with 249 additions and 251 deletions.
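In short: the hand-rolled MessageEnvelope DTO (JSON over Kafka) is replaced with generated protobuf envelope messages (SampleRequestKafkaEnvelope, MaliciousEventKafkaEnvelope), so records travel as raw protobuf bytes. The sketch below is illustrative only, not part of the commit; it uses just the generated-API calls visible in the diffs (newBuilder/set.../build, toByteArray, parseFrom), and the field values are made up.

// Producer side: wrap the sample request in its envelope, serialize to bytes.
SampleMaliciousRequest maliciousReq = SampleMaliciousRequest.newBuilder()
    .setUrl("/api/v1/login")   // illustrative values
    .setMethod("POST")
    .setFilterId("filter-1")
    .build();
SampleRequestKafkaEnvelope out = SampleRequestKafkaEnvelope.newBuilder()
    .setActor("203.0.113.7")   // actor doubles as the IP for now (see the diff below)
    .setMaliciousRequest(maliciousReq)
    .build();
byte[] wire = out.toByteArray(); // this byte[] becomes the Kafka record value

// Consumer side: one binary parse, no JSON/JsonFormat step any more.
// (parseFrom throws InvalidProtocolBufferException, handled in the consumers below.)
SampleRequestKafkaEnvelope in = SampleRequestKafkaEnvelope.parseFrom(wire);
SampleMaliciousRequest evt = in.getMaliciousRequest();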
1 change: 1 addition & 0 deletions .gitignore
@@ -22,3 +22,4 @@ https:
 **/data-zoo-data
 **/data-zoo-logs
 **/bin
+.factorypath

This file was deleted.

@@ -0,0 +1,48 @@
package com.akto.threat.detection.kafka;

import com.akto.kafka.KafkaConfig;
import com.google.protobuf.Message;
import java.time.Duration;
import java.util.Properties;
import org.apache.kafka.clients.producer.*;
import org.apache.kafka.common.serialization.StringSerializer;

public class KafkaProtoProducer {
  private KafkaProducer<String, byte[]> producer;
  public boolean producerReady;

  public KafkaProtoProducer(KafkaConfig kafkaConfig) {
    this.producer = generateProducer(
        kafkaConfig.getBootstrapServers(),
        kafkaConfig.getProducerConfig().getLingerMs(),
        kafkaConfig.getProducerConfig().getBatchSize());
  }

  public void send(String topic, Message message) {
    byte[] messageBytes = message.toByteArray();
    this.producer.send(new ProducerRecord<>(topic, messageBytes));
  }

  public void close() {
    this.producerReady = false;
    producer.close(Duration.ofMillis(0)); // close immediately
  }

  private KafkaProducer<String, byte[]> generateProducer(String brokerIP, int lingerMS, int batchSize) {
    if (producer != null)
      close(); // close existing producer connection

    int requestTimeoutMs = 5000;
    Properties kafkaProps = new Properties();
    kafkaProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerIP);
    kafkaProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG,
        "org.apache.kafka.common.serialization.ByteArraySerializer");
    kafkaProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class);
    kafkaProps.put(ProducerConfig.BATCH_SIZE_CONFIG, batchSize);
    kafkaProps.put(ProducerConfig.LINGER_MS_CONFIG, lingerMS);
    kafkaProps.put(ProducerConfig.RETRIES_CONFIG, 0);
    kafkaProps.put(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG, requestTimeoutMs);
    kafkaProps.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, lingerMS + requestTimeoutMs);
    return new KafkaProducer<String, byte[]>(kafkaProps);
  }
}
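Caller-side usage is then one send() per proto message; a minimal sketch, assuming a KafkaConfig wired the same way as elsewhere in this commit:

KafkaProtoProducer producer = new KafkaProtoProducer(internalConfig); // internalConfig: a KafkaConfig
// send() accepts any generated protobuf type, since it takes com.google.protobuf.Message
// and serializes via toByteArray().
producer.send(KafkaTopic.ThreatDetection.MALICIOUS_EVENTS, envelope);
producer.close();

One detail worth noting: DELIVERY_TIMEOUT_MS_CONFIG is set to lingerMS + requestTimeoutMs, the minimum Kafka accepts (delivery.timeout.ms must be >= linger.ms + request.timeout.ms), which together with RETRIES_CONFIG = 0 gives each record effectively a single delivery attempt.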
@@ -11,9 +11,9 @@
 import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.clients.consumer.KafkaConsumer;

-public abstract class AbstractKafkaConsumerTask implements Task {
+public abstract class AbstractKafkaConsumerTask<V> implements Task {

-  protected Consumer<String, String> kafkaConsumer;
+  protected Consumer<String, V> kafkaConsumer;
   protected KafkaConfig kafkaConfig;
   protected String kafkaTopic;

@@ -24,9 +24,8 @@ public AbstractKafkaConsumerTask(KafkaConfig kafkaConfig, String kafkaTopic) {
     String kafkaBrokerUrl = kafkaConfig.getBootstrapServers();
     String groupId = kafkaConfig.getGroupId();

-    Properties properties =
-        Utils.configProperties(
-            kafkaBrokerUrl, groupId, kafkaConfig.getConsumerConfig().getMaxPollRecords());
+    Properties properties = Utils.configProperties(
+        kafkaBrokerUrl, groupId, kafkaConfig.getConsumerConfig().getMaxPollRecords());
     this.kafkaConsumer = new KafkaConsumer<>(properties);
   }

@@ -40,9 +39,8 @@ public void run() {
         () -> {
           // Poll data from Kafka topic
           while (true) {
-            ConsumerRecords<String, String> records =
-                kafkaConsumer.poll(
-                    Duration.ofMillis(kafkaConfig.getConsumerConfig().getPollDurationMilli()));
+            ConsumerRecords<String, V> records = kafkaConsumer.poll(
+                Duration.ofMillis(kafkaConfig.getConsumerConfig().getPollDurationMilli()));
             if (records.isEmpty()) {
               continue;
             }
@@ -60,5 +58,5 @@ public void run() {
         });
   }

-  abstract void processRecords(ConsumerRecords<String, String> records);
+  abstract void processRecords(ConsumerRecords<String, V> records);
 }
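The record value type is now the generic V, so whatever Utils.configProperties returns has to register a matching value deserializer. That helper is not part of this diff; for the V = byte[] subclasses below, the consumer properties would need to look roughly like this (assumed sketch, not the repo's actual code):

// Assumed consumer config for V = byte[]; mirrors the producer's ByteArraySerializer.
// Uses org.apache.kafka.clients.consumer.ConsumerConfig and
// org.apache.kafka.common.serialization.{StringDeserializer, ByteArrayDeserializer}.
Properties props = new Properties();
props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, kafkaBrokerUrl);
props.put(ConsumerConfig.GROUP_ID_CONFIG, groupId);
props.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, maxPollRecords);
props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class);
props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class);
KafkaConsumer<String, byte[]> consumer = new KafkaConsumer<>(props);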
@@ -3,6 +3,7 @@
 import com.akto.dto.type.URLMethods;
 import com.akto.kafka.KafkaConfig;
 import com.akto.proto.generated.threat_detection.message.sample_request.v1.SampleMaliciousRequest;
+import com.akto.proto.generated.threat_detection.message.sample_request.v1.SampleRequestKafkaEnvelope;
 import com.akto.threat.detection.db.entity.MaliciousEventEntity;
 import com.akto.threat.detection.dto.MessageEnvelope;
 import com.google.protobuf.InvalidProtocolBufferException;
@@ -17,7 +18,7 @@
 /*
 This will read sample malicious data from kafka topic and save it to DB.
 */
-public class FlushSampleDataTask extends AbstractKafkaConsumerTask {
+public class FlushSampleDataTask extends AbstractKafkaConsumerTask<byte[]> {

   private final SessionFactory sessionFactory;

@@ -27,37 +28,29 @@ public FlushSampleDataTask(
     this.sessionFactory = sessionFactory;
   }

-  protected void processRecords(ConsumerRecords<String, String> records) {
+  protected void processRecords(ConsumerRecords<String, byte[]> records) {
     List<MaliciousEventEntity> events = new ArrayList<>();
     records.forEach(
         r -> {
-          String message = r.value();
-          SampleMaliciousRequest.Builder builder = SampleMaliciousRequest.newBuilder();
-          MessageEnvelope m = MessageEnvelope.unmarshal(message).orElse(null);
-          if (m == null) {
-            return;
-          }
-
+          SampleRequestKafkaEnvelope envelope;
           try {
-            JsonFormat.parser().merge(m.getData(), builder);
+            envelope = SampleRequestKafkaEnvelope.parseFrom(r.value());
+            SampleMaliciousRequest evt = envelope.getMaliciousRequest();
+
+            events.add(
+                MaliciousEventEntity.newBuilder()
+                    .setActor(envelope.getActor())
+                    .setFilterId(evt.getFilterId())
+                    .setUrl(evt.getUrl())
+                    .setMethod(URLMethods.Method.fromString(evt.getMethod()))
+                    .setTimestamp(evt.getTimestamp())
+                    .setOrig(evt.getPayload())
+                    .setApiCollectionId(evt.getApiCollectionId())
+                    .setIp(evt.getIp())
+                    .build());
           } catch (InvalidProtocolBufferException e) {
             e.printStackTrace();
             return;
           }
-
-          SampleMaliciousRequest evt = builder.build();
-
-          events.add(
-              MaliciousEventEntity.newBuilder()
-                  .setActor(m.getActor())
-                  .setFilterId(evt.getFilterId())
-                  .setUrl(evt.getUrl())
-                  .setMethod(URLMethods.Method.fromString(evt.getMethod()))
-                  .setTimestamp(evt.getTimestamp())
-                  .setOrig(evt.getPayload())
-                  .setApiCollectionId(evt.getApiCollectionId())
-                  .setIp(evt.getIp())
-                  .build());
         });

     Session session = this.sessionFactory.openSession();
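For contrast, the decode path this hunk deletes versus its replacement, condensed from the removed and added lines above:

// Old transport: two decode steps per record (JSON wrapper, then JSON-encoded proto).
MessageEnvelope m = MessageEnvelope.unmarshal(r.value()).orElse(null);
JsonFormat.parser().merge(m.getData(), builder);

// New transport: one binary parse straight off the record bytes.
SampleRequestKafkaEnvelope envelope = SampleRequestKafkaEnvelope.parseFrom(r.value());
SampleMaliciousRequest evt = envelope.getMaliciousRequest();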
@@ -17,8 +17,10 @@
 import com.akto.kafka.Kafka;
 import com.akto.kafka.KafkaConfig;
 import com.akto.proto.generated.threat_detection.message.malicious_event.event_type.v1.EventType;
+import com.akto.proto.generated.threat_detection.message.malicious_event.v1.MaliciousEventKafkaEnvelope;
 import com.akto.proto.generated.threat_detection.message.malicious_event.v1.MaliciousEventMessage;
 import com.akto.proto.generated.threat_detection.message.sample_request.v1.SampleMaliciousRequest;
+import com.akto.proto.generated.threat_detection.message.sample_request.v1.SampleRequestKafkaEnvelope;
 import com.akto.rules.TestPlugin;
 import com.akto.runtime.utils.Utils;
 import com.akto.test_editor.execution.VariableResolver;
@@ -27,6 +29,7 @@
 import com.akto.threat.detection.cache.RedisBackedCounterCache;
 import com.akto.threat.detection.constants.KafkaTopic;
 import com.akto.threat.detection.dto.MessageEnvelope;
+import com.akto.threat.detection.kafka.KafkaProtoProducer;
 import com.akto.threat.detection.smart_event_detector.window_based.WindowBasedThresholdNotifier;
 import com.google.protobuf.InvalidProtocolBufferException;
 import io.lettuce.core.RedisClient;
@@ -54,7 +57,7 @@ public class MaliciousTrafficDetectorTask implements Task {
   private int filterLastUpdatedAt = 0;
   private int filterUpdateIntervalSec = 300;

-  private final Kafka internalKafka;
+  private final KafkaProtoProducer internalKafka;

   private static final DataActor dataActor = DataActorFactory.fetchInstance();

@@ -65,23 +68,17 @@ public MaliciousTrafficDetectorTask(
     String kafkaBrokerUrl = trafficConfig.getBootstrapServers();
     String groupId = trafficConfig.getGroupId();

-    this.kafkaConsumer =
-        new KafkaConsumer<>(
-            Utils.configProperties(
-                kafkaBrokerUrl, groupId, trafficConfig.getConsumerConfig().getMaxPollRecords()));
+    this.kafkaConsumer = new KafkaConsumer<>(
+        Utils.configProperties(
+            kafkaBrokerUrl, groupId, trafficConfig.getConsumerConfig().getMaxPollRecords()));

     this.httpCallParser = new HttpCallParser(120, 1000);

-    this.windowBasedThresholdNotifier =
-        new WindowBasedThresholdNotifier(
-            new RedisBackedCounterCache(redisClient, "wbt"),
-            new WindowBasedThresholdNotifier.Config(100, 10 * 60));
+    this.windowBasedThresholdNotifier = new WindowBasedThresholdNotifier(
+        new RedisBackedCounterCache(redisClient, "wbt"),
+        new WindowBasedThresholdNotifier.Config(100, 10 * 60));

-    this.internalKafka =
-        new Kafka(
-            internalConfig.getBootstrapServers(),
-            internalConfig.getProducerConfig().getLingerMs(),
-            internalConfig.getProducerConfig().getBatchSize());
+    this.internalKafka = new KafkaProtoProducer(internalConfig);
   }

   public void run() {
@@ -91,9 +88,8 @@ public void run() {
         () -> {
           // Poll data from Kafka topic
           while (true) {
-            ConsumerRecords<String, String> records =
-                kafkaConsumer.poll(
-                    Duration.ofMillis(kafkaConfig.getConsumerConfig().getPollDurationMilli()));
+            ConsumerRecords<String, String> records = kafkaConsumer.poll(
+                Duration.ofMillis(kafkaConfig.getConsumerConfig().getPollDurationMilli()));

             try {
               for (ConsumerRecord<String, String> record : records) {
@@ -133,8 +129,7 @@ private boolean validateFilterForRequest(
       int apiCollectionId = httpCallParser.createApiCollectionId(responseParam);
       responseParam.requestParams.setApiCollectionId(apiCollectionId);
       String url = responseParam.getRequestParams().getURL();
-      URLMethods.Method method =
-          URLMethods.Method.fromString(responseParam.getRequestParams().getMethod());
+      URLMethods.Method method = URLMethods.Method.fromString(responseParam.getRequestParams().getMethod());
       ApiInfo.ApiInfoKey apiInfoKey = new ApiInfo.ApiInfoKey(apiCollectionId, url, method);
       Map<String, Object> varMap = apiFilter.resolveVarMap();
       VariableResolver.resolveWordList(
@@ -146,9 +141,8 @@
           },
           apiInfoKey);
       String filterExecutionLogId = UUID.randomUUID().toString();
-      ValidationResult res =
-          TestPlugin.validateFilter(
-              apiFilter.getFilter().getNode(), rawApi, apiInfoKey, varMap, filterExecutionLogId);
+      ValidationResult res = TestPlugin.validateFilter(
+          apiFilter.getFilter().getNode(), rawApi, apiInfoKey, varMap, filterExecutionLogId);

       return res.getIsValid();
     } catch (Exception e) {
@@ -168,7 +162,7 @@ private void processRecord(ConsumerRecord<String, String> record) {
       return;
     }

-    List<MessageEnvelope> maliciousMessages = new ArrayList<>();
+    List<SampleRequestKafkaEnvelope> maliciousMessages = new ArrayList<>();

     System.out.println("Total number of filters: " + filters.size());

@@ -200,24 +194,22 @@ private void processRecord(ConsumerRecord<String, String> record) {
         String groupKey = apiFilter.getId();
         String aggKey = actor + "|" + groupKey;

-        SampleMaliciousRequest maliciousReq =
-            SampleMaliciousRequest.newBuilder()
-                .setUrl(responseParam.getRequestParams().getURL())
-                .setMethod(responseParam.getRequestParams().getMethod())
-                .setPayload(responseParam.getOrig())
-                .setIp(actor) // For now using actor as IP
-                .setApiCollectionId(responseParam.getRequestParams().getApiCollectionId())
-                .setTimestamp(responseParam.getTime())
-                .setFilterId(apiFilter.getId())
-                .build();
-
-        try {
-          maliciousMessages.add(
-              MessageEnvelope.generateEnvelope(
-                  responseParam.getAccountId(), actor, maliciousReq));
-        } catch (InvalidProtocolBufferException e) {
-          return;
-        }
+        SampleMaliciousRequest maliciousReq = SampleMaliciousRequest.newBuilder()
+            .setUrl(responseParam.getRequestParams().getURL())
+            .setMethod(responseParam.getRequestParams().getMethod())
+            .setPayload(responseParam.getOrig())
+            .setIp(actor) // For now using actor as IP
+            .setApiCollectionId(responseParam.getRequestParams().getApiCollectionId())
+            .setTimestamp(responseParam.getTime())
+            .setFilterId(apiFilter.getId())
+            .build();
+
+        maliciousMessages.add(
+            SampleRequestKafkaEnvelope.newBuilder()
+                .setActor(actor)
+                .setAccountId(responseParam.getAccountId())
+                .setMaliciousRequest(maliciousReq)
+                .build());

         if (!isAggFilter) {
           generateAndPushMaliciousEventRequest(
@@ -227,8 +219,8 @@ private void processRecord(ConsumerRecord<String, String> record) {

         // Aggregation rules
         for (Rule rule : aggRules.getRule()) {
-          WindowBasedThresholdNotifier.Result result =
-              this.windowBasedThresholdNotifier.shouldNotify(aggKey, maliciousReq, rule);
+          WindowBasedThresholdNotifier.Result result = this.windowBasedThresholdNotifier.shouldNotify(aggKey,
+              maliciousReq, rule);

           if (result.shouldNotify()) {
             System.out.print("Notifying for aggregation rule: " + rule);
@@ -249,12 +241,7 @@ private void processRecord(ConsumerRecord<String, String> record) {
     try {
       maliciousMessages.forEach(
           sample -> {
-            sample
-                .marshal()
-                .ifPresent(
-                    data -> {
-                      internalKafka.send(data, KafkaTopic.ThreatDetection.MALICIOUS_EVENTS);
-                    });
+            internalKafka.send(KafkaTopic.ThreatDetection.MALICIOUS_EVENTS, sample);
           });
     } catch (Exception e) {
       e.printStackTrace();
@@ -267,25 +254,27 @@ private void generateAndPushMaliciousEventRequest(
       HttpResponseParams responseParam,
       SampleMaliciousRequest maliciousReq,
       EventType eventType) {
-    MaliciousEventMessage maliciousEvent =
-        MaliciousEventMessage.newBuilder()
-            .setFilterId(apiFilter.getId())
-            .setActor(actor)
-            .setDetectedAt(responseParam.getTime())
-            .setEventType(eventType)
-            .setLatestApiCollectionId(maliciousReq.getApiCollectionId())
-            .setLatestApiIp(maliciousReq.getIp())
-            .setLatestApiPayload(maliciousReq.getPayload())
-            .setLatestApiMethod(maliciousReq.getMethod())
-            .setDetectedAt(responseParam.getTime())
-            .build();
+    MaliciousEventMessage maliciousEvent = MaliciousEventMessage.newBuilder()
+        .setFilterId(apiFilter.getId())
+        .setActor(actor)
+        .setDetectedAt(responseParam.getTime())
+        .setEventType(eventType)
+        .setLatestApiCollectionId(maliciousReq.getApiCollectionId())
+        .setLatestApiIp(maliciousReq.getIp())
+        .setLatestApiPayload(maliciousReq.getPayload())
+        .setLatestApiMethod(maliciousReq.getMethod())
+        .setDetectedAt(responseParam.getTime())
+        .build();
     try {
       System.out.println("Pushing malicious event to kafka: " + maliciousEvent);
+      MaliciousEventKafkaEnvelope envelope = MaliciousEventKafkaEnvelope.newBuilder().setActor(actor)
+          .setAccountId(responseParam.getAccountId())
+          .setMaliciousEvent(maliciousEvent).build();
-      MessageEnvelope.generateEnvelope(responseParam.getAccountId(), actor, maliciousEvent)
-          .marshal()
-          .ifPresent(
-              data -> {
-                internalKafka.send(data, KafkaTopic.ThreatDetection.ALERTS);
+      internalKafka.send(KafkaTopic.ThreatDetection.ALERTS, envelope);
-              });
     } catch (InvalidProtocolBufferException e) {
       e.printStackTrace();