Skip to content

Commit

Permalink
Reset iast request context on root span published (#7969)
Browse files Browse the repository at this point in the history
  • Loading branch information
manuel-alvarez-alvarez authored Nov 19, 2024
1 parent 4dfa404 commit ff8ee85
Show file tree
Hide file tree
Showing 11 changed files with 124 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.datadog.iast.taint.TaintedMap;
import com.datadog.iast.taint.TaintedObjects;
import datadog.trace.api.iast.IastContext;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
Expand All @@ -22,6 +23,9 @@ public TaintedObjects getTaintedObjects() {
return taintedObjects;
}

@Override
public void close() throws IOException {}

public static class Provider extends IastContext.Provider {

// (16384 * 4) buckets: approx 256K
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,24 @@

import com.datadog.iast.taint.TaintedObjects;
import datadog.trace.api.iast.IastContext;
import java.io.IOException;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.jetbrains.annotations.NotNull;

public class IastOptOutContext implements IastContext {

@Nonnull
@SuppressWarnings("unchecked")
@NotNull
@Override
public TaintedObjects getTaintedObjects() {
return TaintedObjects.NoOp.INSTANCE;
}

@Override
public void close() throws IOException {}

public static class Provider extends IastContext.Provider {

final IastContext optOutContext = new IastOptOutContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
import datadog.trace.api.iast.telemetry.IastMetricCollector.HasMetricCollector;
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
import java.io.IOException;
import java.util.Queue;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.function.Consumer;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

Expand All @@ -27,6 +29,7 @@ public class IastRequestContext implements IastContext, HasMetricCollector {
private final VulnerabilityBatch vulnerabilityBatch;
private final OverheadContext overheadContext;
private TaintedObjects taintedObjects;
@Nullable private Consumer<IastContext> release;
@Nullable private IastMetricCollector collector;
@Nullable private volatile String strictTransportSecurity;
@Nullable private volatile String xContentTypeOptions;
Expand Down Expand Up @@ -124,12 +127,20 @@ public void setTaintedObjects(@Nonnull final TaintedObjects taintedObjects) {
this.taintedObjects = taintedObjects;
}

@Override
public void close() throws IOException {
if (release != null) {
release.accept(this);
release = null;
}
}

public static class Provider extends IastContext.Provider {

// 16384 buckets: approx 64K
static final int MAP_SIZE = TaintedMap.DEFAULT_CAPACITY;

final Queue<TaintedObjects> pool =
private final Queue<TaintedObjects> pool =
new ArrayBlockingQueue<>(
Math.max(
Config.get().getIastMaxConcurrentRequests(), DEFAULT_IAST_MAX_CONCURRENT_REQUESTS));
Expand All @@ -154,19 +165,28 @@ public IastContext buildRequestContext() {
if (taintedObjects == null) {
taintedObjects = TaintedObjects.build(TaintedMap.build(MAP_SIZE));
}
return new IastRequestContext(taintedObjects);
final IastRequestContext ctx = new IastRequestContext(taintedObjects);
ctx.release = this::releaseRequestContext;
return ctx;
}

@SuppressWarnings("unchecked")
@Override
public void releaseRequestContext(@Nonnull final IastContext context) {
final TaintedObjects taintedObjects = context.getTaintedObjects();
final IastRequestContext iastCtx = (IastRequestContext) context;

// reset tainted objects map
final TaintedObjects taintedObjects = iastCtx.getTaintedObjects();
taintedObjects.clear();
// add the root instance to the pool
if (taintedObjects instanceof Wrapper) {
pool.offer(((Wrapper<TaintedObjects>) taintedObjects).unwrap());
} else {
pool.offer(taintedObjects);

// return to pool and update internal ref
final TaintedObjects unwrapped =
taintedObjects instanceof Wrapper
? ((Wrapper<TaintedObjects>) taintedObjects).unwrap()
: taintedObjects;
if (unwrapped != TaintedObjects.NoOp.INSTANCE) {
pool.offer(unwrapped);
iastCtx.setTaintedObjects(TaintedObjects.NoOp.INSTANCE);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.datadog.iast

import com.datadog.iast.model.Range
import com.datadog.iast.taint.TaintedObjects
import datadog.trace.api.Config
import datadog.trace.api.gateway.RequestContext
import datadog.trace.api.gateway.RequestContextSlot
import datadog.trace.bootstrap.instrumentation.api.AgentSpan
Expand Down Expand Up @@ -53,6 +54,15 @@ class IastRequestContextTest extends DDSpecification {

then:
1 * tracer.activeSpan() >> span
1 * span.getRequestContext() >> null
resolved == null

when:
resolved = provider.resolve()

then:
1 * tracer.activeSpan() >> span
1 * span.getRequestContext() >> reqCtx
resolved === initialCtx
}

Expand All @@ -72,5 +82,42 @@ class IastRequestContextTest extends DDSpecification {
then:
to.count() == 0
provider.pool.size() == 1

when:
final maxPoolSize = Config.get().getIastMaxConcurrentRequests()
final list = (1..2 * maxPoolSize).collect {
provider.buildRequestContext()
}

then:
provider.pool.size() == 0

when:
list.each { provider.releaseRequestContext(it) }

then:
provider.pool.size() == maxPoolSize
}

void 'ensure that the context releases all tainted objects on close'() {
setup:
final ctx = provider.buildRequestContext() as IastRequestContext

when:
ctx.withCloseable {
it.taintedObjects.taint(UUID.randomUUID(), [] as Range[])
}

then:
ctx.taintedObjects.count() == 0

when:
ctx.withCloseable {
it.taintedObjects.taint(UUID.randomUUID(), [] as Range[])
assert it.taintedObjects.count() == 0
}

then:
ctx.taintedObjects.count() == 0
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class IastSystemTest extends DDSpecification {

then:
1 * iastContext.getTaintedObjects()
1 * iastContext.setTaintedObjects(_)
1 * iastContext.getMetricCollector()
1 * traceSegment.setTagTop('_dd.iast.enabled', 1)
1 * iastContext.getxContentTypeOptions() >> 'nosniff'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import datadog.trace.agent.test.AgentTestRunner
import datadog.trace.agent.tooling.bytebuddy.iast.TaintableVisitor
import datadog.trace.api.gateway.CallbackProvider
import datadog.trace.api.gateway.Events
import datadog.trace.api.gateway.Flow
import datadog.trace.api.gateway.RequestContextSlot
import datadog.trace.api.iast.IastContext
import datadog.trace.api.iast.SourceTypes
Expand All @@ -15,7 +14,6 @@ import datadog.trace.bootstrap.instrumentation.api.AgentTracer
import datadog.trace.bootstrap.instrumentation.api.TagContext
import datadog.trace.core.DDSpan

import java.util.function.Supplier

class IastAgentTestRunner extends AgentTestRunner implements IastRequestContextPreparationTrait {
public static final EMPTY_SOURCE = new Source(SourceTypes.NONE, '', '')
Expand All @@ -40,25 +38,18 @@ class IastAgentTestRunner extends AgentTestRunner implements IastRequestContextP
IastContext.Provider.get().taintedObjects
}

protected TaintedObjectCollection getLocalTaintedObjectCollection() {
new TaintedObjectCollection(localTaintedObjects)
}

protected TaintedObjectCollection getTaintedObjectCollection(DDSpan span) {
final IastContext ctx = span.getRequestContext().getData(RequestContextSlot.IAST)
return new TaintedObjectCollection(ctx.getTaintedObjects())
}

protected DDSpan runUnderIastTrace(Closure cl) {
CallbackProvider iastCbp = TEST_TRACER.getCallbackProvider(RequestContextSlot.IAST)
Supplier<Flow<Object>> reqStartCb = iastCbp.getCallback(Events.EVENTS.requestStarted())
def reqStartCb = iastCbp.getCallback(Events.EVENTS.requestStarted())
def reqEndCb = iastCbp.getCallback(Events.EVENTS.requestEnded())

def iastCtx = reqStartCb.get().result
def ddctx = new TagContext().withRequestContextDataIast(iastCtx)
AgentSpan span = TEST_TRACER.startSpan("test", "test-iast-span", ddctx)
try {
AgentTracer.activateSpan(span).withCloseable cl
} finally {
reqEndCb.apply(span.requestContext, span)
span.finish()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,27 @@ import static datadog.trace.bootstrap.instrumentation.api.AgentTracer.get
trait IastRequestContextPreparationTrait {

static void iastSystemSetup(Closure reqEndAction = null) {
def ss = AgentTracer.get().getSubscriptionService(RequestContextSlot.IAST)
final tracer = AgentTracer.get()
def ss = tracer.getSubscriptionService(RequestContextSlot.IAST)
IastSystem.start(ss, new NoopOverheadController())

EventType<Supplier<Flow<Object>>> requestStarted = Events.get().requestStarted()
EventType<BiFunction<RequestContext, IGSpanInfo, Flow<Void>>> requestEnded =
Events.get().requestEnded()

// get original callbacks
CallbackProvider provider = AgentTracer.get().getCallbackProvider(RequestContextSlot.IAST)
CallbackProvider provider = tracer.getCallbackProvider(RequestContextSlot.IAST)
def origRequestStarted = provider.getCallback(requestStarted)
def origRequestEnded = provider.getCallback(requestEnded)

// wrap the original IG callbacks
ss.reset()
ss.registerCallback(requestStarted, new TaintedMapSaveStrongRefsRequestStarted(orig: origRequestStarted))
if (reqEndAction != null) {
ss.registerCallback(requestEnded, new TaintedMapSavingRequestEnded(
original: origRequestEnded, beforeAction: reqEndAction))
}
ss.registerCallback(
requestEnded,
reqEndAction == null
? origRequestEnded
: new TaintedMapSavingRequestEnded(original: origRequestEnded, beforeAction: reqEndAction))
}

static void iastSystemCleanup() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package iast

import com.datadog.iast.propagation.PropagationModuleImpl
import com.datadog.iast.test.IastAgentTestRunner
import com.datadog.iast.test.IastRequestTestRunner
import datadog.trace.api.iast.InstrumentationBridge
import datadog.trace.api.iast.SourceTypes
import org.apache.kafka.common.header.internals.RecordHeaders
Expand All @@ -16,7 +16,7 @@ import java.nio.ByteBuffer
import static org.hamcrest.CoreMatchers.instanceOf
import static org.hamcrest.core.IsEqual.equalTo

class KafkaIastDeserializerTest extends IastAgentTestRunner {
class KafkaIastDeserializerTest extends IastRequestTestRunner {

private static final int BUFF_OFFSET = 10

Expand All @@ -31,13 +31,13 @@ class KafkaIastDeserializerTest extends IastAgentTestRunner {
final deserializer = new StringDeserializer()

when:
final span = runUnderIastTrace {
runUnderIastTrace {
deserializer.configure([:], origin == SourceTypes.KAFKA_MESSAGE_KEY)
test.method.deserialize(deserializer, "test", payload)
}

then:
final to = getTaintedObjectCollection(span)
final to = finReqTaintedObjects
to.hasTaintedObject {
value('Hello World!')
range(0, 12, source(origin))
Expand All @@ -58,13 +58,13 @@ class KafkaIastDeserializerTest extends IastAgentTestRunner {
final deserializer = new ByteArrayDeserializer()

when:
final span = runUnderIastTrace {
runUnderIastTrace {
deserializer.configure([:], origin == SourceTypes.KAFKA_MESSAGE_KEY)
test.method.deserialize(deserializer, "test", payload)
}

then:
final to = getTaintedObjectCollection(span)
final to = finReqTaintedObjects
to.hasTaintedObject {
value(equalTo(payload))
range(0, Integer.MAX_VALUE, source(origin))
Expand All @@ -85,13 +85,13 @@ class KafkaIastDeserializerTest extends IastAgentTestRunner {
final deserializer = new ByteBufferDeserializer()

when:
final span = runUnderIastTrace {
runUnderIastTrace {
deserializer.configure([:], origin == SourceTypes.KAFKA_MESSAGE_KEY)
test.method.deserialize(deserializer, "test", payload)
}

then:
final to = getTaintedObjectCollection(span)
final to = finReqTaintedObjects
to.hasTaintedObject {
value(instanceOf(ByteBuffer))
range(0, Integer.MAX_VALUE, source(origin))
Expand All @@ -113,13 +113,13 @@ class KafkaIastDeserializerTest extends IastAgentTestRunner {
final deserializer = new JsonDeserializer(TestBean)

when:
final span = runUnderIastTrace {
runUnderIastTrace {
deserializer.configure([:], origin == SourceTypes.KAFKA_MESSAGE_KEY)
test.method.deserialize(deserializer, 'test', payload)
}

then:
final to = getTaintedObjectCollection(span)
final to = finReqTaintedObjects
to.hasTaintedObject {
value(instanceOf(TestBean))
range(0, Integer.MAX_VALUE, source(origin as byte))
Expand Down
Loading

0 comments on commit ff8ee85

Please sign in to comment.