Skip to content

Commit

Permalink
using proto message envelope for kafka transport in threat detection …
Browse files Browse the repository at this point in the history
…client (#1874)
  • Loading branch information
ag060 committed Jan 7, 2025
1 parent 84928bf commit 7725089
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 220 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package com.akto.threat.detection.kafka;

import com.akto.kafka.KafkaConfig;
import com.google.protobuf.Message;
import java.time.Duration;
import java.util.Properties;
import org.apache.kafka.clients.producer.*;
import org.apache.kafka.common.serialization.StringSerializer;

public class KafkaProtoProducer {
private KafkaProducer<String, byte[]> producer;
public boolean producerReady;

public KafkaProtoProducer(KafkaConfig kafkaConfig) {
this.producer = generateProducer(
kafkaConfig.getBootstrapServers(),
kafkaConfig.getProducerConfig().getLingerMs(),
kafkaConfig.getProducerConfig().getBatchSize());
}

public void send(String topic, Message message) {
byte[] messageBytes = message.toByteArray();
this.producer.send(new ProducerRecord<>(topic, messageBytes));
}

public void close() {
this.producerReady = false;
producer.close(Duration.ofMillis(0)); // close immediately
}

private KafkaProducer<String, byte[]> generateProducer(String brokerIP, int lingerMS, int batchSize) {
if (producer != null)
close(); // close existing producer connection

int requestTimeoutMs = 5000;
Properties kafkaProps = new Properties();
kafkaProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerIP);
kafkaProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG,
"org.apache.kafka.common.serialization.ByteArraySerializer");
kafkaProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class);
kafkaProps.put(ProducerConfig.BATCH_SIZE_CONFIG, batchSize);
kafkaProps.put(ProducerConfig.LINGER_MS_CONFIG, lingerMS);
kafkaProps.put(ProducerConfig.RETRIES_CONFIG, 0);
kafkaProps.put(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG, requestTimeoutMs);
kafkaProps.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, lingerMS + requestTimeoutMs);
return new KafkaProducer<String, byte[]>(kafkaProps);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.common.serialization.StringDeserializer;

public abstract class AbstractKafkaConsumerTask implements Task {
public abstract class AbstractKafkaConsumerTask<V> implements Task {

protected Consumer<String, String> kafkaConsumer;
protected Consumer<String, V> kafkaConsumer;
protected KafkaConfig kafkaConfig;
protected String kafkaTopic;

Expand All @@ -24,9 +26,16 @@ public AbstractKafkaConsumerTask(KafkaConfig kafkaConfig, String kafkaTopic) {
String kafkaBrokerUrl = kafkaConfig.getBootstrapServers();
String groupId = kafkaConfig.getGroupId();

Properties properties =
Utils.configProperties(
kafkaBrokerUrl, groupId, kafkaConfig.getConsumerConfig().getMaxPollRecords());
Properties properties = new Properties();
properties.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, kafkaBrokerUrl);
properties.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class);
properties.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG,
"org.apache.kafka.common.serialization.ByteArrayDeserializer");
properties.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, kafkaConfig.getConsumerConfig().getMaxPollRecords());
properties.put(ConsumerConfig.GROUP_ID_CONFIG, kafkaConfig.getGroupId());
properties.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
properties.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false);

this.kafkaConsumer = new KafkaConsumer<>(properties);
}

Expand All @@ -40,9 +49,8 @@ public void run() {
() -> {
// Poll data from Kafka topic
while (true) {
ConsumerRecords<String, String> records =
kafkaConsumer.poll(
Duration.ofMillis(kafkaConfig.getConsumerConfig().getPollDurationMilli()));
ConsumerRecords<String, V> records = kafkaConsumer.poll(
Duration.ofMillis(kafkaConfig.getConsumerConfig().getPollDurationMilli()));
if (records.isEmpty()) {
continue;
}
Expand All @@ -60,5 +68,5 @@ public void run() {
});
}

abstract void processRecords(ConsumerRecords<String, String> records);
abstract void processRecords(ConsumerRecords<String, V> records);
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.akto.dto.type.URLMethods;
import com.akto.kafka.KafkaConfig;
import com.akto.proto.generated.threat_detection.message.sample_request.v1.SampleMaliciousRequest;
import com.akto.proto.generated.threat_detection.message.sample_request.v1.SampleRequestKafkaEnvelope;
import com.akto.threat.detection.db.entity.MaliciousEventEntity;
import com.akto.threat.detection.dto.MessageEnvelope;
import com.google.protobuf.InvalidProtocolBufferException;
Expand All @@ -17,7 +18,7 @@
/*
This will read sample malicious data from kafka topic and save it to DB.
*/
public class FlushSampleDataTask extends AbstractKafkaConsumerTask {
public class FlushSampleDataTask extends AbstractKafkaConsumerTask<byte[]> {

private final SessionFactory sessionFactory;

Expand All @@ -27,37 +28,29 @@ public FlushSampleDataTask(
this.sessionFactory = sessionFactory;
}

protected void processRecords(ConsumerRecords<String, String> records) {
protected void processRecords(ConsumerRecords<String, byte[]> records) {
List<MaliciousEventEntity> events = new ArrayList<>();
records.forEach(
r -> {
String message = r.value();
SampleMaliciousRequest.Builder builder = SampleMaliciousRequest.newBuilder();
MessageEnvelope m = MessageEnvelope.unmarshal(message).orElse(null);
if (m == null) {
return;
}

SampleRequestKafkaEnvelope envelope;
try {
JsonFormat.parser().merge(m.getData(), builder);
envelope = SampleRequestKafkaEnvelope.parseFrom(r.value());
SampleMaliciousRequest evt = envelope.getMaliciousRequest();

events.add(
MaliciousEventEntity.newBuilder()
.setActor(envelope.getActor())
.setFilterId(evt.getFilterId())
.setUrl(evt.getUrl())
.setMethod(URLMethods.Method.fromString(evt.getMethod()))
.setTimestamp(evt.getTimestamp())
.setOrig(evt.getPayload())
.setApiCollectionId(evt.getApiCollectionId())
.setIp(evt.getIp())
.build());
} catch (InvalidProtocolBufferException e) {
e.printStackTrace();
return;
}

SampleMaliciousRequest evt = builder.build();

events.add(
MaliciousEventEntity.newBuilder()
.setActor(m.getActor())
.setFilterId(evt.getFilterId())
.setUrl(evt.getUrl())
.setMethod(URLMethods.Method.fromString(evt.getMethod()))
.setTimestamp(evt.getTimestamp())
.setOrig(evt.getPayload())
.setApiCollectionId(evt.getApiCollectionId())
.setIp(evt.getIp())
.build());
});

Session session = this.sessionFactory.openSession();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
import com.akto.dto.test_editor.YamlTemplate;
import com.akto.dto.type.URLMethods;
import com.akto.hybrid_parsers.HttpCallParser;
import com.akto.kafka.Kafka;
import com.akto.kafka.KafkaConfig;
import com.akto.proto.generated.threat_detection.message.malicious_event.event_type.v1.EventType;
import com.akto.proto.generated.threat_detection.message.malicious_event.v1.MaliciousEventKafkaEnvelope;
import com.akto.proto.generated.threat_detection.message.malicious_event.v1.MaliciousEventMessage;
import com.akto.proto.generated.threat_detection.message.sample_request.v1.SampleMaliciousRequest;
import com.akto.proto.generated.threat_detection.message.sample_request.v1.SampleRequestKafkaEnvelope;
import com.akto.rules.TestPlugin;
import com.akto.runtime.utils.Utils;
import com.akto.test_editor.execution.VariableResolver;
Expand All @@ -27,6 +28,7 @@
import com.akto.threat.detection.cache.RedisBackedCounterCache;
import com.akto.threat.detection.constants.KafkaTopic;
import com.akto.threat.detection.dto.MessageEnvelope;
import com.akto.threat.detection.kafka.KafkaProtoProducer;
import com.akto.threat.detection.smart_event_detector.window_based.WindowBasedThresholdNotifier;
import com.google.protobuf.InvalidProtocolBufferException;
import io.lettuce.core.RedisClient;
Expand Down Expand Up @@ -54,7 +56,7 @@ public class MaliciousTrafficDetectorTask implements Task {
private int filterLastUpdatedAt = 0;
private int filterUpdateIntervalSec = 300;

private final Kafka internalKafka;
private final KafkaProtoProducer internalKafka;

private static final DataActor dataActor = DataActorFactory.fetchInstance();

Expand All @@ -77,11 +79,7 @@ public MaliciousTrafficDetectorTask(
new RedisBackedCounterCache(redisClient, "wbt"),
new WindowBasedThresholdNotifier.Config(100, 10 * 60));

this.internalKafka =
new Kafka(
internalConfig.getBootstrapServers(),
internalConfig.getProducerConfig().getLingerMs(),
internalConfig.getProducerConfig().getBatchSize());
this.internalKafka = new KafkaProtoProducer(internalConfig);
}

public void run() {
Expand Down Expand Up @@ -123,19 +121,20 @@ private Map<String, FilterConfig> getFilters() {
return apiFilters;
}

private boolean validateFilterForRequest(FilterConfig apiFilter, RawApi rawApi, ApiInfo.ApiInfoKey apiInfoKey, String message) {
private boolean validateFilterForRequest(
FilterConfig apiFilter, RawApi rawApi, ApiInfo.ApiInfoKey apiInfoKey, String message) {
try {
System.out.println("using buildFromMessageNew func");

Map<String, Object> varMap = apiFilter.resolveVarMap();
VariableResolver.resolveWordList(
varMap,
new HashMap<ApiInfo.ApiInfoKey, List<String>>() {
{
put(apiInfoKey, Collections.singletonList(message));
}
},
apiInfoKey);
varMap,
new HashMap<ApiInfo.ApiInfoKey, List<String>>() {
{
put(apiInfoKey, Collections.singletonList(message));
}
},
apiInfoKey);

String filterExecutionLogId = UUID.randomUUID().toString();
ValidationResult res =
Expand All @@ -160,7 +159,7 @@ private void processRecord(ConsumerRecord<String, String> record) {
return;
}

List<MessageEnvelope> maliciousMessages = new ArrayList<>();
List<SampleRequestKafkaEnvelope> maliciousMessages = new ArrayList<>();

System.out.println("Total number of filters: " + filters.size());

Expand Down Expand Up @@ -212,13 +211,12 @@ private void processRecord(ConsumerRecord<String, String> record) {
.setFilterId(apiFilter.getId())
.build();

try {
maliciousMessages.add(
MessageEnvelope.generateEnvelope(
responseParam.getAccountId(), actor, maliciousReq));
} catch (InvalidProtocolBufferException e) {
return;
}
maliciousMessages.add(
SampleRequestKafkaEnvelope.newBuilder()
.setActor(actor)
.setAccountId(responseParam.getAccountId())
.setMaliciousRequest(maliciousReq)
.build());

if (!isAggFilter) {
generateAndPushMaliciousEventRequest(
Expand Down Expand Up @@ -250,12 +248,7 @@ private void processRecord(ConsumerRecord<String, String> record) {
try {
maliciousMessages.forEach(
sample -> {
sample
.marshal()
.ifPresent(
data -> {
internalKafka.send(data, KafkaTopic.ThreatDetection.MALICIOUS_EVENTS);
});
internalKafka.send(KafkaTopic.ThreatDetection.MALICIOUS_EVENTS, sample);
});
} catch (Exception e) {
e.printStackTrace();
Expand All @@ -281,12 +274,18 @@ private void generateAndPushMaliciousEventRequest(
.setDetectedAt(responseParam.getTime())
.build();
try {
System.out.println("Pushing malicious event to kafka: ");
System.out.println("Pushing malicious event to kafka: " + maliciousEvent);
MaliciousEventKafkaEnvelope envelope =
MaliciousEventKafkaEnvelope.newBuilder()
.setActor(actor)
.setAccountId(responseParam.getAccountId())
.setMaliciousEvent(maliciousEvent)
.build();
MessageEnvelope.generateEnvelope(responseParam.getAccountId(), actor, maliciousEvent)
.marshal()
.ifPresent(
data -> {
internalKafka.send(data, KafkaTopic.ThreatDetection.ALERTS);
internalKafka.send(KafkaTopic.ThreatDetection.ALERTS, envelope);
});
} catch (InvalidProtocolBufferException e) {
e.printStackTrace();
Expand Down
Loading

0 comments on commit 7725089

Please sign in to comment.