Using proto message envelope for Kafka transport #1874

Merged · 1 commit · Jan 1, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -22,3 +22,4 @@ https:
 **/data-zoo-data
 **/data-zoo-logs
 **/bin
+.factorypath

This file was deleted.

KafkaProtoProducer.java (new file)
@@ -0,0 +1,48 @@
package com.akto.threat.detection.kafka;

import com.akto.kafka.KafkaConfig;
import com.google.protobuf.Message;
import java.time.Duration;
import java.util.Properties;
import org.apache.kafka.clients.producer.*;
import org.apache.kafka.common.serialization.StringSerializer;

// Thin wrapper around KafkaProducer that publishes protobuf messages as raw bytes.
public class KafkaProtoProducer {
  private KafkaProducer<String, byte[]> producer;
  public boolean producerReady;

  public KafkaProtoProducer(KafkaConfig kafkaConfig) {
    this.producer =
        generateProducer(
            kafkaConfig.getBootstrapServers(),
            kafkaConfig.getProducerConfig().getLingerMs(),
            kafkaConfig.getProducerConfig().getBatchSize());
  }

  // Serializes any generated protobuf message to its binary wire format and
  // publishes it on the given topic.
  public void send(String topic, Message message) {
    byte[] messageBytes = message.toByteArray();
    this.producer.send(new ProducerRecord<>(topic, messageBytes));
  }

  public void close() {
    this.producerReady = false;
    producer.close(Duration.ofMillis(0)); // close immediately
  }

  private KafkaProducer<String, byte[]> generateProducer(
      String brokerIP, int lingerMS, int batchSize) {
    if (producer != null) close(); // close existing producer connection

    int requestTimeoutMs = 5000;
    Properties kafkaProps = new Properties();
    kafkaProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerIP);
    kafkaProps.put(
        ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG,
        "org.apache.kafka.common.serialization.ByteArraySerializer");
    kafkaProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class);
    kafkaProps.put(ProducerConfig.BATCH_SIZE_CONFIG, batchSize);
    kafkaProps.put(ProducerConfig.LINGER_MS_CONFIG, lingerMS);
    kafkaProps.put(ProducerConfig.RETRIES_CONFIG, 0);
    kafkaProps.put(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG, requestTimeoutMs);
    kafkaProps.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, lingerMS + requestTimeoutMs);
    return new KafkaProducer<>(kafkaProps);
  }
}
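
For context, any generated protobuf type can be handed to send(), since serialization goes through Message.toByteArray(). A minimal usage sketch (the kafkaConfig variable, topic string, and field values below are illustrative, not taken from this PR):

    // Sketch: publish a protobuf envelope through the new producer.
    // `kafkaConfig` is assumed to be built elsewhere in the app.
    KafkaProtoProducer producer = new KafkaProtoProducer(kafkaConfig);
    SampleRequestKafkaEnvelope envelope =
        SampleRequestKafkaEnvelope.newBuilder()
            .setActor("203.0.113.7") // illustrative actor value
            .build();
    producer.send("akto.threat.sample-requests", envelope); // illustrative topic name
    producer.close();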
AbstractKafkaConsumerTask.java
@@ -10,10 +10,12 @@
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.common.serialization.StringDeserializer;

-public abstract class AbstractKafkaConsumerTask implements Task {
+public abstract class AbstractKafkaConsumerTask<V> implements Task {

-  protected Consumer<String, String> kafkaConsumer;
+  protected Consumer<String, V> kafkaConsumer;
   protected KafkaConfig kafkaConfig;
   protected String kafkaTopic;

@@ -24,9 +26,16 @@ public AbstractKafkaConsumerTask(KafkaConfig kafkaConfig, String kafkaTopic) {
     String kafkaBrokerUrl = kafkaConfig.getBootstrapServers();
     String groupId = kafkaConfig.getGroupId();

-    Properties properties =
-        Utils.configProperties(
-            kafkaBrokerUrl, groupId, kafkaConfig.getConsumerConfig().getMaxPollRecords());
+    Properties properties = new Properties();
+    properties.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, kafkaBrokerUrl);
+    properties.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class);
+    properties.put(
+        ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG,
+        "org.apache.kafka.common.serialization.ByteArrayDeserializer");
+    properties.put(
+        ConsumerConfig.MAX_POLL_RECORDS_CONFIG,
+        kafkaConfig.getConsumerConfig().getMaxPollRecords());
+    properties.put(ConsumerConfig.GROUP_ID_CONFIG, kafkaConfig.getGroupId());
+    properties.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
+    properties.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false);

     this.kafkaConsumer = new KafkaConsumer<>(properties);
   }

@@ -40,9 +49,8 @@ public void run() {
         () -> {
           // Poll data from Kafka topic
           while (true) {
-            ConsumerRecords<String, String> records =
-                kafkaConsumer.poll(
-                    Duration.ofMillis(kafkaConfig.getConsumerConfig().getPollDurationMilli()));
+            ConsumerRecords<String, V> records =
+                kafkaConsumer.poll(
+                    Duration.ofMillis(kafkaConfig.getConsumerConfig().getPollDurationMilli()));
             if (records.isEmpty()) {
               continue;
             }
@@ -60,5 +68,5 @@ public void run() {
         });
   }

-  abstract void processRecords(ConsumerRecords<String, String> records);
+  abstract void processRecords(ConsumerRecords<String, V> records);
 }
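
Since the task is now generic over the record value type and the consumer is configured with ByteArrayDeserializer, each subclass chooses how to decode the raw bytes. A minimal sketch of a concrete consumer in the style of this PR (the class name and body are illustrative; it assumes the same package as the base class, as with FlushSampleDataTask):

    // Illustrative subclass: consumes raw protobuf bytes and parses the envelope.
    public class ExampleEnvelopeConsumerTask extends AbstractKafkaConsumerTask<byte[]> {

      public ExampleEnvelopeConsumerTask(KafkaConfig kafkaConfig, String kafkaTopic) {
        super(kafkaConfig, kafkaTopic);
      }

      @Override
      protected void processRecords(ConsumerRecords<String, byte[]> records) {
        records.forEach(
            r -> {
              try {
                // Record values are binary protobuf envelopes, not JSON strings.
                SampleRequestKafkaEnvelope envelope =
                    SampleRequestKafkaEnvelope.parseFrom(r.value());
                // ... handle envelope.getMaliciousRequest() ...
              } catch (InvalidProtocolBufferException e) {
                e.printStackTrace(); // mirrors the error handling used in this PR
              }
            });
      }
    }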
FlushSampleDataTask.java
@@ -3,6 +3,7 @@
 import com.akto.dto.type.URLMethods;
 import com.akto.kafka.KafkaConfig;
 import com.akto.proto.generated.threat_detection.message.sample_request.v1.SampleMaliciousRequest;
+import com.akto.proto.generated.threat_detection.message.sample_request.v1.SampleRequestKafkaEnvelope;
 import com.akto.threat.detection.db.entity.MaliciousEventEntity;
 import com.akto.threat.detection.dto.MessageEnvelope;
 import com.google.protobuf.InvalidProtocolBufferException;
@@ -17,7 +18,7 @@
 /*
 This will read sample malicious data from kafka topic and save it to DB.
 */
-public class FlushSampleDataTask extends AbstractKafkaConsumerTask {
+public class FlushSampleDataTask extends AbstractKafkaConsumerTask<byte[]> {

   private final SessionFactory sessionFactory;

@@ -27,37 +28,29 @@ public FlushSampleDataTask(
     this.sessionFactory = sessionFactory;
   }

-  protected void processRecords(ConsumerRecords<String, String> records) {
+  protected void processRecords(ConsumerRecords<String, byte[]> records) {
     List<MaliciousEventEntity> events = new ArrayList<>();
     records.forEach(
         r -> {
-          String message = r.value();
-          SampleMaliciousRequest.Builder builder = SampleMaliciousRequest.newBuilder();
-          MessageEnvelope m = MessageEnvelope.unmarshal(message).orElse(null);
-          if (m == null) {
-            return;
-          }
-
+          SampleRequestKafkaEnvelope envelope;
           try {
-            JsonFormat.parser().merge(m.getData(), builder);
+            envelope = SampleRequestKafkaEnvelope.parseFrom(r.value());
+            SampleMaliciousRequest evt = envelope.getMaliciousRequest();
+
+            events.add(
+                MaliciousEventEntity.newBuilder()
+                    .setActor(envelope.getActor())
+                    .setFilterId(evt.getFilterId())
+                    .setUrl(evt.getUrl())
+                    .setMethod(URLMethods.Method.fromString(evt.getMethod()))
+                    .setTimestamp(evt.getTimestamp())
+                    .setOrig(evt.getPayload())
+                    .setApiCollectionId(evt.getApiCollectionId())
+                    .setIp(evt.getIp())
+                    .build());
           } catch (InvalidProtocolBufferException e) {
             e.printStackTrace();
             return;
           }
-
-          SampleMaliciousRequest evt = builder.build();
-
-          events.add(
-              MaliciousEventEntity.newBuilder()
-                  .setActor(m.getActor())
-                  .setFilterId(evt.getFilterId())
-                  .setUrl(evt.getUrl())
-                  .setMethod(URLMethods.Method.fromString(evt.getMethod()))
-                  .setTimestamp(evt.getTimestamp())
-                  .setOrig(evt.getPayload())
-                  .setApiCollectionId(evt.getApiCollectionId())
-                  .setIp(evt.getIp())
-                  .build());
         });

     Session session = this.sessionFactory.openSession();
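
The net effect here is one less decoding hop on the consumer side: the old path unwrapped a JSON MessageEnvelope and then ran protobuf's JSON parser over its payload, while the new path parses the generated envelope straight from the record bytes. Side by side (local variable names are illustrative, and both fragments would need the checked InvalidProtocolBufferException handled as in the code above):

    // Old path (removed): two-step decode via JSON.
    MessageEnvelope m = MessageEnvelope.unmarshal(jsonString).orElse(null);
    SampleMaliciousRequest.Builder builder = SampleMaliciousRequest.newBuilder();
    JsonFormat.parser().merge(m.getData(), builder);
    SampleMaliciousRequest evt = builder.build();

    // New path: single binary parse, with the actor carried in the envelope.
    SampleRequestKafkaEnvelope envelope = SampleRequestKafkaEnvelope.parseFrom(recordBytes);
    SampleMaliciousRequest evtNew = envelope.getMaliciousRequest();
    String actor = envelope.getActor();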
MaliciousTrafficDetectorTask.java
@@ -14,11 +14,12 @@
 import com.akto.dto.test_editor.YamlTemplate;
 import com.akto.dto.type.URLMethods;
 import com.akto.hybrid_parsers.HttpCallParser;
-import com.akto.kafka.Kafka;
 import com.akto.kafka.KafkaConfig;
 import com.akto.proto.generated.threat_detection.message.malicious_event.event_type.v1.EventType;
+import com.akto.proto.generated.threat_detection.message.malicious_event.v1.MaliciousEventKafkaEnvelope;
 import com.akto.proto.generated.threat_detection.message.malicious_event.v1.MaliciousEventMessage;
 import com.akto.proto.generated.threat_detection.message.sample_request.v1.SampleMaliciousRequest;
+import com.akto.proto.generated.threat_detection.message.sample_request.v1.SampleRequestKafkaEnvelope;
 import com.akto.rules.TestPlugin;
 import com.akto.runtime.utils.Utils;
 import com.akto.test_editor.execution.VariableResolver;
@@ -27,6 +28,7 @@
 import com.akto.threat.detection.cache.RedisBackedCounterCache;
 import com.akto.threat.detection.constants.KafkaTopic;
 import com.akto.threat.detection.dto.MessageEnvelope;
+import com.akto.threat.detection.kafka.KafkaProtoProducer;
 import com.akto.threat.detection.smart_event_detector.window_based.WindowBasedThresholdNotifier;
 import com.google.protobuf.InvalidProtocolBufferException;
 import io.lettuce.core.RedisClient;
@@ -54,7 +56,7 @@ public class MaliciousTrafficDetectorTask implements Task {
   private int filterLastUpdatedAt = 0;
   private int filterUpdateIntervalSec = 300;

-  private final Kafka internalKafka;
+  private final KafkaProtoProducer internalKafka;

   private static final DataActor dataActor = DataActorFactory.fetchInstance();
@@ -77,11 +79,7 @@ public MaliciousTrafficDetectorTask(
             new RedisBackedCounterCache(redisClient, "wbt"),
             new WindowBasedThresholdNotifier.Config(100, 10 * 60));

-    this.internalKafka =
-        new Kafka(
-            internalConfig.getBootstrapServers(),
-            internalConfig.getProducerConfig().getLingerMs(),
-            internalConfig.getProducerConfig().getBatchSize());
+    this.internalKafka = new KafkaProtoProducer(internalConfig);
   }

   public void run() {
@@ -123,19 +121,20 @@ private Map<String, FilterConfig> getFilters() {
     return apiFilters;
   }

-  private boolean validateFilterForRequest(FilterConfig apiFilter, RawApi rawApi, ApiInfo.ApiInfoKey apiInfoKey, String message) {
+  private boolean validateFilterForRequest(
+      FilterConfig apiFilter, RawApi rawApi, ApiInfo.ApiInfoKey apiInfoKey, String message) {
     try {
       System.out.println("using buildFromMessageNew func");

       Map<String, Object> varMap = apiFilter.resolveVarMap();
       VariableResolver.resolveWordList(
           varMap,
           new HashMap<ApiInfo.ApiInfoKey, List<String>>() {
             {
               put(apiInfoKey, Collections.singletonList(message));
             }
           },
           apiInfoKey);

       String filterExecutionLogId = UUID.randomUUID().toString();
       ValidationResult res =
@@ -160,7 +159,7 @@ private void processRecord(ConsumerRecord<String, String> record) {
       return;
     }

-    List<MessageEnvelope> maliciousMessages = new ArrayList<>();
+    List<SampleRequestKafkaEnvelope> maliciousMessages = new ArrayList<>();

     System.out.println("Total number of filters: " + filters.size());
@@ -212,13 +211,12 @@ private void processRecord(ConsumerRecord<String, String> record) {
                 .setFilterId(apiFilter.getId())
                 .build();

-        try {
-          maliciousMessages.add(
-              MessageEnvelope.generateEnvelope(
-                  responseParam.getAccountId(), actor, maliciousReq));
-        } catch (InvalidProtocolBufferException e) {
-          return;
-        }
+        maliciousMessages.add(
+            SampleRequestKafkaEnvelope.newBuilder()
+                .setActor(actor)
+                .setAccountId(responseParam.getAccountId())
+                .setMaliciousRequest(maliciousReq)
+                .build());

         if (!isAggFilter) {
           generateAndPushMaliciousEventRequest(
@@ -250,12 +248,7 @@ private void processRecord(ConsumerRecord<String, String> record) {
     try {
       maliciousMessages.forEach(
           sample -> {
-            sample
-                .marshal()
-                .ifPresent(
-                    data -> {
-                      internalKafka.send(data, KafkaTopic.ThreatDetection.MALICIOUS_EVENTS);
-                    });
+            internalKafka.send(KafkaTopic.ThreatDetection.MALICIOUS_EVENTS, sample);
           });
     } catch (Exception e) {
       e.printStackTrace();
Expand All @@ -281,12 +274,18 @@ private void generateAndPushMaliciousEventRequest(
.setDetectedAt(responseParam.getTime())
.build();
try {
System.out.println("Pushing malicious event to kafka: ");
System.out.println("Pushing malicious event to kafka: " + maliciousEvent);
MaliciousEventKafkaEnvelope envelope =
MaliciousEventKafkaEnvelope.newBuilder()
.setActor(actor)
.setAccountId(responseParam.getAccountId())
.setMaliciousEvent(maliciousEvent)
.build();
MessageEnvelope.generateEnvelope(responseParam.getAccountId(), actor, maliciousEvent)
.marshal()
.ifPresent(
data -> {
internalKafka.send(data, KafkaTopic.ThreatDetection.ALERTS);
internalKafka.send(KafkaTopic.ThreatDetection.ALERTS, envelope);
});
} catch (InvalidProtocolBufferException e) {
e.printStackTrace();
Expand Down
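
With this change, producer and consumer agree on the wire format end to end: whatever KafkaProtoProducer writes with toByteArray() is exactly what the consumer tasks read back with parseFrom(). A quick local round-trip sketch (field values are illustrative; parseFrom throws the checked InvalidProtocolBufferException, so it must run inside a try block or a method that declares it):

    MaliciousEventKafkaEnvelope out =
        MaliciousEventKafkaEnvelope.newBuilder()
            .setActor("203.0.113.7") // illustrative actor
            .build();
    byte[] wire = out.toByteArray(); // producer side: KafkaProtoProducer.send()
    MaliciousEventKafkaEnvelope in = MaliciousEventKafkaEnvelope.parseFrom(wire); // consumer side
    assert in.getActor().equals(out.getActor());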