Skip to content

Commit

Permalink
sending sample malicious events to backend only once
Browse files Browse the repository at this point in the history
  • Loading branch information
ag060 committed Dec 23, 2024
1 parent e273f0b commit f640fdd
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ public long getValue() {

private final Cache<String, Long> localCache;

private final ConcurrentLinkedQueue<Op> pendingOps;
private final ConcurrentLinkedQueue<Op> pendingIncOps;
private final ConcurrentMap<String, Boolean> deletedKeys;
private final String prefix;

public RedisBackedCounterCache(RedisClient redisClient, String prefix) {
Expand All @@ -44,7 +45,8 @@ public RedisBackedCounterCache(RedisClient redisClient, String prefix) {
ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
executor.scheduleAtFixedRate(this::syncToRedis, 60, 5, TimeUnit.SECONDS);

this.pendingOps = new ConcurrentLinkedQueue<>();
this.pendingIncOps = new ConcurrentLinkedQueue<>();
this.deletedKeys = new ConcurrentHashMap<>();
}

private String addPrefixToKey(String key) {
Expand All @@ -60,7 +62,7 @@ public void increment(String key) {
public void incrementBy(String key, long val) {
String _key = addPrefixToKey(key);
localCache.asMap().merge(_key, val, Long::sum);
pendingOps.add(new Op(_key, val));
pendingIncOps.add(new Op(_key, val));

this.setExpiryIfNotSet(_key, 3 * 60 * 60); // added 3 hours expiry for now
}
Expand All @@ -77,8 +79,10 @@ public boolean exists(String key) {

@Override
public void clear(String key) {
localCache.invalidate(addPrefixToKey(key));
redis.async().del(addPrefixToKey(key));
String _key = addPrefixToKey(key);
localCache.invalidate(_key);
this.deletedKeys.put(_key, true);
redis.async().del(_key);
}

private void setExpiryIfNotSet(String key, long seconds) {
Expand All @@ -89,10 +93,15 @@ private void setExpiryIfNotSet(String key, long seconds) {
}

private void syncToRedis() {
while (!pendingOps.isEmpty()) {
Op op = pendingOps.poll();
while (!pendingIncOps.isEmpty()) {
Op op = pendingIncOps.poll();
String key = op.getKey();
long val = op.getValue();

if (this.deletedKeys.containsKey(key)) {
continue;
}

redis
.async()
.incrby(key, val)
Expand All @@ -107,5 +116,7 @@ private void syncToRedis() {
}
});
}

this.deletedKeys.clear();
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package com.akto.threat.detection.db.entity;

import com.akto.dto.type.URLMethods;

import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.util.UUID;

import javax.persistence.Column;
import javax.persistence.Entity;
import javax.persistence.EnumType;
Expand All @@ -14,7 +12,6 @@
import javax.persistence.Id;
import javax.persistence.PrePersist;
import javax.persistence.Table;

import org.hibernate.annotations.GenericGenerator;

@Entity
Expand Down Expand Up @@ -55,6 +52,9 @@ public class MaliciousEventEntity {
@Column(name = "created_at", updatable = false)
private LocalDateTime createdAt;

@Column(name = "_alerted_to_backend")
private boolean alertedToBackend;

public MaliciousEventEntity() {}

@PrePersist
Expand Down Expand Up @@ -172,6 +172,10 @@ public LocalDateTime getCreatedAt() {
return createdAt;
}

public boolean isAlertedToBackend() {
return alertedToBackend;
}

@Override
public String toString() {
return "MaliciousEventEntity{"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,14 @@ public void run() {
continue;
}

processRecords(records);
try {
processRecords(records);

if (!records.isEmpty()) {
kafkaConsumer.commitSync();
if (!records.isEmpty()) {
kafkaConsumer.commitSync();
}
} catch (Exception ex) {
ex.printStackTrace();
}
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ContentType;
Expand All @@ -39,14 +40,32 @@ public SendMaliciousEventsToBackend(
this.httpClient = HttpClients.createDefault();
}

private void markSampleDataAsSent(List<UUID> ids) {
Session session = this.sessionFactory.openSession();
Transaction txn = session.beginTransaction();
try {
session
.createQuery(
"update MaliciousEventEntity m set m.alertedToBackend = true where m.id in :ids")
.setParameterList("ids", ids)
.executeUpdate();
} catch (Exception ex) {
ex.printStackTrace();
txn.rollback();
} finally {
txn.commit();
session.close();
}
}

private List<MaliciousEventEntity> getSampleMaliciousRequests(String actor, String filterId) {
Session session = this.sessionFactory.openSession();
Transaction txn = session.beginTransaction();
try {
return session
.createQuery(
"from MaliciousEventEntity m where m.actor = :actor and m.filterId = :filterId order"
+ " by m.createdAt desc",
"from MaliciousEventEntity m where m.actor = :actor and m.filterId = :filterId and"
+ " m.alertedToBackend = false order by m.createdAt desc",
MaliciousEventEntity.class)
.setParameter("actor", actor)
.setParameter("filterId", filterId)
Expand Down Expand Up @@ -83,12 +102,13 @@ protected void processRecords(ConsumerRecords<String, String> records) {
MaliciousEventMessage evt = builder.build();

// Get sample data from postgres for this alert
List<MaliciousEventEntity> sampleData =
this.getSampleMaliciousRequests(evt.getActor(), evt.getFilterId());
try {
RecordMaliciousEventRequest.Builder reqBuilder =
RecordMaliciousEventRequest.newBuilder().setMaliciousEvent(evt);
if (EventType.EVENT_TYPE_AGGREGATED.equals(evt.getEventType())) {
List<MaliciousEventEntity> sampleData =
this.getSampleMaliciousRequests(evt.getActor(), evt.getFilterId());
sampleData = this.getSampleMaliciousRequests(evt.getActor(), evt.getFilterId());

reqBuilder.addAllSampleRequests(
sampleData.stream()
Expand All @@ -105,9 +125,13 @@ protected void processRecords(ConsumerRecords<String, String> records) {
.collect(Collectors.toList()));
}

List<UUID> sampleIds =
sampleData.stream().map(MaliciousEventEntity::getId).collect(Collectors.toList());

RecordMaliciousEventRequest maliciousEventRequest = reqBuilder.build();
String url = System.getenv("AKTO_THREAT_PROTECTION_BACKEND_URL");
String token = System.getenv("AKTO_THREAT_PROTECTION_BACKEND_TOKEN");
ProtoMessageUtils.toString(reqBuilder.build())
ProtoMessageUtils.toString(maliciousEventRequest)
.ifPresent(
msg -> {
StringEntity requestEntity =
Expand All @@ -123,6 +147,10 @@ protected void processRecords(ConsumerRecords<String, String> records) {
} catch (IOException e) {
e.printStackTrace();
}

if (!sampleIds.isEmpty()) {
markSampleDataAsSent(sampleIds);
}
});
} catch (Exception e) {
e.printStackTrace();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
alter table threat_detection.malicious_event add column _alerted_to_backend boolean default false;

-- set all existing rows to false
update threat_detection.malicious_event set _alerted_to_backend = false;

0 comments on commit f640fdd

Please sign in to comment.