Skip to content

Commit

Permalink
Merge pull request #1112 from akto-api-security/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
shivam-rawat-akto authored May 15, 2024
2 parents 87c6519 + fe851a9 commit 6614cfd
Show file tree
Hide file tree
Showing 11 changed files with 416 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import com.mongodb.client.model.Updates;
import com.mongodb.BasicDBObject;
import com.mongodb.client.model.Sorts;
import com.mongodb.client.model.UnwindOptions;
import com.opensymphony.xwork2.Action;
import org.apache.commons.lang3.tuple.Pair;
import org.bson.Document;
Expand Down Expand Up @@ -426,7 +427,15 @@ public String fetchLastSeenInfoInCollections(){
public String fetchRiskScoreInfo(){
Map<Integer, Double> riskScoreMap = new HashMap<>();
List<Bson> pipeline = new ArrayList<>();
BasicDBObject groupId = new BasicDBObject("apiCollectionId", "$_id.apiCollectionId");

/*
* Use Unwind to unwind the collectionIds field resulting in a document for each collectionId in the collectionIds array
*/
UnwindOptions unwindOptions = new UnwindOptions();
unwindOptions.preserveNullAndEmptyArrays(false);
pipeline.add(Aggregates.unwind("$collectionIds", unwindOptions));

BasicDBObject groupId = new BasicDBObject("apiCollectionId", "$collectionIds");
pipeline.add(Aggregates.sort(
Sorts.descending(ApiInfo.RISK_SCORE)
));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,34 @@ public static void createLoginSignupGroups(BackwardCompatibility backwardCompati
}
}

/**
 * Creates and persists a new API-group collection whose membership is driven by a
 * risk-score condition.
 *
 * @param id                  fixed collection id reserved for this risk score group
 * @param name                display name of the group
 * @param riskScoreGroupType  which risk band (LOW/MEDIUM/HIGH) this group tracks
 */
public static void createRiskScoreApiGroup(int id, String name, RiskScoreTestingEndpoints.RiskScoreGroupType riskScoreGroupType) {
    loggerMaker.infoAndAddToDb("Creating risk score group: " + name, LogDb.DASHBOARD);

    // Single condition: membership is decided by the API's risk score band.
    List<TestingEndpoints> conditions = new ArrayList<>();
    conditions.add(new RiskScoreTestingEndpoints(riskScoreGroupType));

    ApiCollection group = new ApiCollection(id, name, Context.now(), new HashSet<>(), null, 0, false, false);
    group.setType(ApiCollection.Type.API_GROUP);
    group.setConditions(conditions);

    ApiCollectionsDao.instance.insertOne(group);
}

/**
 * One-time backward-compatibility migration: creates the three fixed risk score API
 * groups (low/medium/high) and stamps the account's BackwardCompatibility record so
 * the migration never runs twice.
 */
public static void createRiskScoreGroups(BackwardCompatibility backwardCompatibility) {
    // Non-zero timestamp means the groups were already created in an earlier run.
    if (backwardCompatibility.getRiskScoreGroups() != 0) {
        return;
    }

    createRiskScoreApiGroup(111_111_148, "Low Risk APIs", RiskScoreTestingEndpoints.RiskScoreGroupType.LOW);
    createRiskScoreApiGroup(111_111_149, "Medium Risk APIs", RiskScoreTestingEndpoints.RiskScoreGroupType.MEDIUM);
    createRiskScoreApiGroup(111_111_150, "High Risk APIs", RiskScoreTestingEndpoints.RiskScoreGroupType.HIGH);

    // Record the migration timestamp so subsequent startups skip this step.
    BackwardCompatibilityDao.instance.updateOne(
            Filters.eq("_id", backwardCompatibility.getId()),
            Updates.set(BackwardCompatibility.RISK_SCORE_GROUPS, Context.now())
    );
}

public static void dropWorkflowTestResultCollection(BackwardCompatibility backwardCompatibility) {
if (backwardCompatibility.getDropWorkflowTestResult() == 0) {
WorkflowTestResultsDao.instance.getMCollection().drop();
Expand Down Expand Up @@ -1538,7 +1566,22 @@ public void updateCustomCollections() {
List<ApiCollection> apiCollections = ApiCollectionsDao.instance.findAll(new BasicDBObject());
for (ApiCollection apiCollection : apiCollections) {
if (ApiCollection.Type.API_GROUP.equals(apiCollection.getType())) {
ApiCollectionUsers.computeCollectionsForCollectionId(apiCollection.getConditions(), apiCollection.getId());
List<TestingEndpoints> conditions = apiCollection.getConditions();

// Don't update API groups that are delta update based
boolean isDeltaUpdateBasedApiGroup = false;
for (TestingEndpoints testingEndpoints : conditions) {
if (TestingEndpoints.checkDeltaUpdateBased(testingEndpoints.getType())) {
isDeltaUpdateBasedApiGroup = true;
break;
}
}

if (isDeltaUpdateBasedApiGroup) {
continue;
}

ApiCollectionUsers.computeCollectionsForCollectionId(conditions, apiCollection.getId());
}
}
}
Expand Down Expand Up @@ -1983,6 +2026,7 @@ public static void setBackwardCompatibilities(BackwardCompatibility backwardComp
dropAuthMechanismData(backwardCompatibility);
moveAuthMechanismDataToRole(backwardCompatibility);
createLoginSignupGroups(backwardCompatibility);
createRiskScoreGroups(backwardCompatibility);
deleteAccessListFromApiToken(backwardCompatibility);
deleteNullSubCategoryIssues(backwardCompatibility);
enableNewMerging(backwardCompatibility);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
import com.akto.dto.AccountSettings;
import com.akto.dto.ApiInfo;
import com.akto.dto.AktoDataType;
import com.akto.dto.ApiCollectionUsers;
import com.akto.dto.ApiInfo.ApiInfoKey;
import com.akto.dto.CustomDataType;
import com.akto.dto.test_run_findings.TestingIssuesId;
import com.akto.dto.test_run_findings.TestingRunIssues;
import com.akto.dto.testing.RiskScoreTestingEndpoints;
import com.akto.dto.testing.TestingEndpoints;
import com.akto.dto.testing.RiskScoreTestingEndpoints.RiskScoreGroupType;
import com.akto.dto.type.SingleTypeInfo;
import com.akto.log.LoggerMaker;
import com.akto.log.LoggerMaker.LogDb;
Expand Down Expand Up @@ -140,7 +144,9 @@ public void updateSeverityScoreInApiInfo(int timeStampFilter){
return ;
}

Map<ApiInfoKey, Float> severityScoreMap = getSeverityScoreMap(updatedIssues);
Map<ApiInfoKey, Float> severityScoreMap = getSeverityScoreMap(updatedIssues);

RiskScoreTestingEndpointsUtils riskScoreTestingEndpointsUtils = new RiskScoreTestingEndpointsUtils();

// after getting the severityScoreMap, we write that in DB
if(severityScoreMap != null){
Expand All @@ -149,7 +155,13 @@ public void updateSeverityScoreInApiInfo(int timeStampFilter){
ApiInfo apiInfo = ApiInfoDao.instance.findOne(filter);
boolean isSensitive = apiInfo != null ? apiInfo.getIsSensitive() : false;
float riskScore = ApiInfoDao.getRiskScore(apiInfo, isSensitive, Utils.getRiskScoreValueFromSeverityScore(severityScore));


if (apiInfo != null) {
if (apiInfo.getRiskScore() != riskScore) {
riskScoreTestingEndpointsUtils.updateApiRiskScoreGroup(apiInfo, riskScore);
}
}

Bson update = Updates.combine(
Updates.set(ApiInfo.SEVERITY_SCORE, severityScore),
Updates.set(ApiInfo.RISK_SCORE, riskScore)
Expand All @@ -161,7 +173,8 @@ public void updateSeverityScoreInApiInfo(int timeStampFilter){
if (bulkUpdatesForApiInfo.size() > 0) {
ApiInfoDao.instance.getMCollection().bulkWrite(bulkUpdatesForApiInfo, new BulkWriteOptions().ordered(false));
}


riskScoreTestingEndpointsUtils.syncRiskScoreGroupApis();
}

private static void writeUpdatesForSensitiveInfoInApiInfo(List<String> updatedDataTypes, int timeStampFilter){
Expand Down Expand Up @@ -258,6 +271,9 @@ public void calculateRiskScoreForAllApis() {
Filters.lte(ApiInfo.LAST_CALCULATED_TIME, timeStamp)
);
Bson projection = Projections.include("_id", ApiInfo.API_ACCESS_TYPES, ApiInfo.LAST_SEEN, ApiInfo.SEVERITY_SCORE, ApiInfo.IS_SENSITIVE);

RiskScoreTestingEndpointsUtils riskScoreTestingEndpointsUtils = new RiskScoreTestingEndpointsUtils();

while(count < 100){
List<ApiInfo> apiInfos = ApiInfoDao.instance.findAll(filter,0, limit, Sorts.descending(ApiInfo.LAST_CALCULATED_TIME), projection);
for(ApiInfo apiInfo: apiInfos){
Expand All @@ -269,6 +285,19 @@ public void calculateRiskScoreForAllApis() {
Bson filterQ = ApiInfoDao.getFilter(apiInfo.getId());

bulkUpdates.add(new UpdateManyModel<>(filterQ, update, new UpdateOptions().upsert(false)));

List<Integer> collectionIds = apiInfo.getCollectionIds();
float oldRiskScore = apiInfo.getRiskScore();
RiskScoreTestingEndpoints.RiskScoreGroupType oldRiskScoreGroupType = RiskScoreTestingEndpoints.calculateRiskScoreGroup(oldRiskScore);
int oldRiskScoreGroupCollectionId = RiskScoreTestingEndpoints.getApiCollectionId(oldRiskScoreGroupType);

if (!collectionIds.contains(oldRiskScoreGroupCollectionId)) {
// Add API to risk score API group if it is not already added
riskScoreTestingEndpointsUtils.updateApiRiskScoreGroup(apiInfo, riskScore);
} else if (oldRiskScore != riskScore) {
// Update API in risk score API group if risk score has changed
riskScoreTestingEndpointsUtils.updateApiRiskScoreGroup(apiInfo, riskScore);
}
}
if(bulkUpdates.size() > 0){
ApiInfoDao.instance.bulkWrite(bulkUpdates, new BulkWriteOptions().ordered(false));
Expand All @@ -279,5 +308,7 @@ public void calculateRiskScoreForAllApis() {
break;
}
}

riskScoreTestingEndpointsUtils.syncRiskScoreGroupApis();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package com.akto.utils;

import java.util.ArrayList;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

import com.akto.dao.context.Context;
import com.akto.dto.ApiCollectionUsers;
import com.akto.dto.ApiInfo;
import com.akto.dto.testing.RiskScoreTestingEndpoints;
import com.akto.dto.testing.TestingEndpoints;
import com.akto.log.LoggerMaker;
import com.akto.log.LoggerMaker.LogDb;

/**
 * Accumulates risk-score-group membership changes for APIs whose risk score changed,
 * then flushes them to the risk score API-group collections in the background.
 *
 * Usage: call {@link #updateApiRiskScoreGroup} once per changed API, then call
 * {@link #syncRiskScoreGroupApis()} to apply all queued changes asynchronously.
 *
 * NOTE(review): instances are not thread-safe — the queue maps are mutated without
 * synchronization, and the async flush reads them; callers appear to queue everything
 * before calling sync (confirm against call sites).
 */
public class RiskScoreTestingEndpointsUtils {
    private static final LoggerMaker loggerMaker = new LoggerMaker(RiskScoreTestingEndpointsUtils.class);

    // APIs queued for removal, keyed by the group matching their OLD risk score.
    // EnumMap + constructor population replaces the previous double-brace HashMap
    // initialization (which created anonymous inner classes).
    private final Map<RiskScoreTestingEndpoints.RiskScoreGroupType, List<ApiInfo>> removeApisFromRiskScoreGroupMap =
            new EnumMap<>(RiskScoreTestingEndpoints.RiskScoreGroupType.class);

    // APIs queued for insertion, keyed by the group matching their NEW risk score.
    private final Map<RiskScoreTestingEndpoints.RiskScoreGroupType, List<ApiInfo>> addApisToRiskScoreGroupMap =
            new EnumMap<>(RiskScoreTestingEndpoints.RiskScoreGroupType.class);

    // Single worker thread so group updates are applied sequentially in the background.
    private static final ExecutorService executorService = Executors.newFixedThreadPool(1);

    public RiskScoreTestingEndpointsUtils() {
        // Pre-populate every group type so get(...) never returns null.
        for (RiskScoreTestingEndpoints.RiskScoreGroupType groupType : RiskScoreTestingEndpoints.RiskScoreGroupType.values()) {
            removeApisFromRiskScoreGroupMap.put(groupType, new ArrayList<>());
            addApisToRiskScoreGroupMap.put(groupType, new ArrayList<>());
        }
    }

    /**
     * Queues a membership change for an API whose risk score changed: remove it from the
     * group derived from its stored (old) score and add it to the group derived from
     * {@code updatedRiskScore}. No DB writes happen here — see {@link #syncRiskScoreGroupApis()}.
     *
     * @param apiInfo          the API whose score changed; its current getRiskScore() is
     *                         treated as the old score
     * @param updatedRiskScore the newly computed risk score
     */
    public void updateApiRiskScoreGroup(ApiInfo apiInfo, float updatedRiskScore) {
        float oldRiskScore = apiInfo.getRiskScore();

        RiskScoreTestingEndpoints.RiskScoreGroupType oldGroupType = RiskScoreTestingEndpoints.calculateRiskScoreGroup(oldRiskScore);
        removeApisFromRiskScoreGroupMap.get(oldGroupType).add(apiInfo);

        RiskScoreTestingEndpoints.RiskScoreGroupType newGroupType = RiskScoreTestingEndpoints.calculateRiskScoreGroup(updatedRiskScore);
        addApisToRiskScoreGroupMap.get(newGroupType).add(apiInfo);
    }

    /**
     * Applies one membership change (add or remove) to a group's collection in batches of
     * {@code RiskScoreTestingEndpoints.BATCH_SIZE}, by loading each batch into the shared
     * filter object and invoking the corresponding ApiCollectionUsers operation.
     */
    private static void applyBatchedMembershipChange(RiskScoreTestingEndpoints riskScoreTestingEndpoints,
                                                     List<TestingEndpoints> testingEndpoints,
                                                     int apiCollectionId,
                                                     List<ApiInfo> apis,
                                                     boolean add) {
        for (int start = 0; start < apis.size(); start += RiskScoreTestingEndpoints.BATCH_SIZE) {
            int end = Math.min(start + RiskScoreTestingEndpoints.BATCH_SIZE, apis.size());

            // The filter object is reused across batches; each call overwrites the batch.
            riskScoreTestingEndpoints.setFilterRiskScoreGroupApis(apis.subList(start, end));

            if (add) {
                ApiCollectionUsers.addToCollectionsForCollectionId(testingEndpoints, apiCollectionId);
            } else {
                ApiCollectionUsers.removeFromCollectionsForCollectionId(testingEndpoints, apiCollectionId);
            }
        }
    }

    // Flushes all queued removals and additions, one risk score group at a time.
    // Removals run before additions for each group, so an API moving within the same
    // group ends up present (remove then add).
    private void updateRiskScoreApiGroups() {
        try {
            for (RiskScoreTestingEndpoints.RiskScoreGroupType riskScoreGroupType : RiskScoreTestingEndpoints.RiskScoreGroupType.values()) {
                RiskScoreTestingEndpoints riskScoreTestingEndpoints = new RiskScoreTestingEndpoints(riskScoreGroupType);

                List<TestingEndpoints> testingEndpoints = new ArrayList<>();
                testingEndpoints.add(riskScoreTestingEndpoints);
                int apiCollectionId = RiskScoreTestingEndpoints.getApiCollectionId(riskScoreGroupType);

                // Remove APIs from the original risk score group
                List<ApiInfo> apisToRemove = removeApisFromRiskScoreGroupMap.get(riskScoreGroupType);
                loggerMaker.infoAndAddToDb("Removing " + apisToRemove.size() + " APIs from risk score group - " + riskScoreGroupType, LogDb.DASHBOARD);
                applyBatchedMembershipChange(riskScoreTestingEndpoints, testingEndpoints, apiCollectionId, apisToRemove, false);

                // Add APIs to the new risk score group
                List<ApiInfo> apisToAdd = addApisToRiskScoreGroupMap.get(riskScoreGroupType);
                loggerMaker.infoAndAddToDb("Adding " + apisToAdd.size() + " APIs to risk score group - " + riskScoreGroupType, LogDb.DASHBOARD);
                applyBatchedMembershipChange(riskScoreTestingEndpoints, testingEndpoints, apiCollectionId, apisToAdd, true);
            }
        } catch (Exception e) {
            loggerMaker.errorAndAddToDb("Error updating risk score group APIs - " + e.getMessage(), LogDb.DASHBOARD);
        }
    }

    /**
     * Submits the queued group updates to the background worker. The caller's account id
     * is captured on this thread and re-established inside the task so DB writes target
     * the correct account.
     */
    public void syncRiskScoreGroupApis() {
        int accountId = Context.accountId.get();

        try {
            executorService.submit(() -> {
                Context.accountId.set(accountId);
                loggerMaker.infoAndAddToDb("Updating risk score API groups", LogDb.DASHBOARD);
                updateRiskScoreApiGroups();
            });
        } catch (Exception e) {
            loggerMaker.errorAndAddToDb("Error syncing risk score group APIs - " + e.getMessage(), LogDb.DASHBOARD);
        }
    }
}
4 changes: 3 additions & 1 deletion libs/dao/src/main/java/com/akto/DaoInit.java
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ public static CodecRegistry createCodecRegistry(){
ClassModel<CodeAnalysisApiLocation> codeAnalysisApiLocationClassModel = ClassModel.builder(CodeAnalysisApiLocation.class).enableDiscriminator(true).build();
ClassModel<CodeAnalysisApiInfo> codeAnalysisApiInfoClassModel = ClassModel.builder(CodeAnalysisApiInfo.class).enableDiscriminator(true).build();
ClassModel<CodeAnalysisApiInfo.CodeAnalysisApiInfoKey> codeAnalysisApiInfoKeyClassModel = ClassModel.builder(CodeAnalysisApiInfo.CodeAnalysisApiInfoKey.class).enableDiscriminator(true).build();
ClassModel<RiskScoreTestingEndpoints> riskScoreTestingEndpointsClassModel = ClassModel.builder(RiskScoreTestingEndpoints.class).enableDiscriminator(true).build();

CodecRegistry pojoCodecRegistry = fromProviders(PojoCodecProvider.builder().register(
configClassModel, signupInfoClassModel, apiAuthClassModel, attempResultModel, urlTemplateModel,
Expand Down Expand Up @@ -271,7 +272,8 @@ public static CodecRegistry createCodecRegistry(){
UsageMetricClassModel, UsageMetricInfoClassModel, UsageSyncClassModel, OrganizationClassModel,
yamlNodeDetails, multiExecTestResultClassModel, workflowTestClassModel, dependencyNodeClassModel, paramInfoClassModel,
nodeClassModel, connectionClassModel, edgeClassModel, replaceDetailClassModel, modifyHostDetailClassModel, fileUploadClassModel
,fileUploadLogClassModel, codeAnalysisCollectionClassModel, codeAnalysisApiLocationClassModel, codeAnalysisApiInfoClassModel, codeAnalysisApiInfoKeyClassModel).automatic(true).build());
,fileUploadLogClassModel, codeAnalysisCollectionClassModel, codeAnalysisApiLocationClassModel, codeAnalysisApiInfoClassModel, codeAnalysisApiInfoKeyClassModel,
riskScoreTestingEndpointsClassModel).automatic(true).build());

final CodecRegistry customEnumCodecs = CodecRegistries.fromCodecs(
new EnumCodec<>(Conditions.Operator.class),
Expand Down
14 changes: 12 additions & 2 deletions libs/dao/src/main/java/com/akto/dao/ApiInfoDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.Projections;
import com.mongodb.client.model.Sorts;
import com.mongodb.client.model.UnwindOptions;
import com.mongodb.client.model.Updates;

import org.bson.Document;
Expand Down Expand Up @@ -91,7 +92,11 @@ public Map<Integer,Integer> getCoverageCount(){
int oneMonthAgo = Context.now() - Constants.ONE_MONTH_TIMESTAMP ;
pipeline.add(Aggregates.match(Filters.gte("lastTested", oneMonthAgo)));

BasicDBObject groupedId2 = new BasicDBObject("apiCollectionId", "$_id.apiCollectionId");
UnwindOptions unwindOptions = new UnwindOptions();
unwindOptions.preserveNullAndEmptyArrays(false);
pipeline.add(Aggregates.unwind("$collectionIds", unwindOptions));

BasicDBObject groupedId2 = new BasicDBObject("apiCollectionId", "$collectionIds");
pipeline.add(Aggregates.group(groupedId2, Accumulators.sum("count",1)));
pipeline.add(Aggregates.project(
Projections.fields(
Expand All @@ -115,7 +120,12 @@ public Map<Integer,Integer> getCoverageCount(){
public Map<Integer,Integer> getLastTrafficSeen(){
Map<Integer,Integer> result = new HashMap<>();
List<Bson> pipeline = new ArrayList<>();
BasicDBObject groupedId = new BasicDBObject("apiCollectionId", "$_id.apiCollectionId");

UnwindOptions unwindOptions = new UnwindOptions();
unwindOptions.preserveNullAndEmptyArrays(false);
pipeline.add(Aggregates.unwind("$collectionIds", unwindOptions));

BasicDBObject groupedId = new BasicDBObject("apiCollectionId", "$collectionIds");
pipeline.add(Aggregates.sort(Sorts.orderBy(Sorts.descending(ApiInfo.ID_API_COLLECTION_ID), Sorts.descending(ApiInfo.LAST_SEEN))));
pipeline.add(Aggregates.group(groupedId, Accumulators.first(ApiInfo.LAST_SEEN, "$lastSeen")));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.mongodb.client.model.Aggregates;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.Projections;
import com.mongodb.client.model.UnwindOptions;

public class TestingRunIssuesDao extends AccountsContextDao<TestingRunIssues> {

Expand Down Expand Up @@ -61,7 +62,11 @@ public Map<Integer,Map<String,Integer>> getSeveritiesMapForCollections(){
List<Bson> pipeline = new ArrayList<>();
pipeline.add(Aggregates.match(Filters.eq(TestingRunIssues.TEST_RUN_ISSUES_STATUS, "OPEN")));

BasicDBObject groupedId = new BasicDBObject("apiCollectionId", "$_id.apiInfoKey.apiCollectionId")
UnwindOptions unwindOptions = new UnwindOptions();
unwindOptions.preserveNullAndEmptyArrays(false);
pipeline.add(Aggregates.unwind("$collectionIds", unwindOptions));

BasicDBObject groupedId = new BasicDBObject("apiCollectionId", "$collectionIds")
.append("severity", "$severity") ;

pipeline.add(Aggregates.group(groupedId, Accumulators.sum("count", 1)));
Expand Down
Loading

0 comments on commit 6614cfd

Please sign in to comment.