Commit

Reverted instrumentation-specific code; switched to thread local
kr-igor committed Nov 15, 2024
1 parent 0e9f3b6 commit 6856dc5
Showing 8 changed files with 24 additions and 152 deletions.
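At its core, this change replaces a static ConcurrentHashMap<Long, String> keyed by thread id with a ThreadLocal<String> in DefaultDataStreamsMonitoring, so callers no longer pass a thread id when setting or clearing the per-thread service name override. Below is a minimal standalone sketch of that pattern, assuming nothing beyond what the diff shows; the wrapper class name is invented for illustration.

// Illustrative sketch only; the real change lives in DefaultDataStreamsMonitoring (see diff below).
public final class ThreadServiceNameSketch {
  // Holds the service name override for the current thread; replaces a
  // ConcurrentHashMap<Long, String> that was keyed by Thread.getId().
  private static final ThreadLocal<String> serviceNameOverride = new ThreadLocal<>();

  public static void setThreadServiceName(String serviceName) {
    if (serviceName == null) {
      clearThreadServiceName(); // passing null removes the override
      return;
    }
    serviceNameOverride.set(serviceName);
  }

  public static void clearThreadServiceName() {
    serviceNameOverride.remove(); // avoid leaking the value on pooled, reused threads
  }

  public static String getThreadServiceName() {
    return serviceNameOverride.get();
  }
}
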
11 changes: 0 additions & 11 deletions dd-java-agent/instrumentation/spark-executor/build.gradle
@@ -33,24 +33,13 @@ ext {
dependencies {
compileOnly group: 'org.apache.spark', name: 'spark-core_2.12', version: '2.4.0'
compileOnly group: 'org.apache.spark', name: 'spark-sql_2.12', version: '2.4.0'
compileOnly group: 'org.apache.spark', name:'spark-sql-kafka-0-10_2.12', version: "2.4.0"

baseTestImplementation group: 'org.apache.spark', name: "spark-core_2.12", version: "2.4.0"
baseTestImplementation group: 'org.apache.spark', name: "spark-sql_2.12", version: "2.4.0"
baseTestImplementation group: 'org.apache.spark', name: "spark-sql_2.12", version: "2.4.0"
baseTestImplementation group: 'org.apache.spark', name:'spark-sql-kafka-0-10_2.12', version: "2.4.0"
testImplementation group: 'org.apache.kafka', name: 'kafka_2.12', version: '2.4.0'
testImplementation group: 'org.apache.kafka', name: 'kafka-clients', version: '2.4.0'
testImplementation group: 'org.springframework.kafka', name: 'spring-kafka', version: '2.4.0.RELEASE'
testImplementation group: 'org.springframework.kafka', name: 'spring-kafka-test', version: '2.4.0.RELEASE'

latest212DepTestImplementation group: 'org.apache.spark', name: "spark-core_2.12", version: '3.+'
latest212DepTestImplementation group: 'org.apache.spark', name: "spark-sql_2.12", version: '3.+'
latest212DepTestImplementation group: 'org.apache.spark', name: "spark-sql_2.12", version: "3.+"
latest212DepTestImplementation group: 'org.apache.spark', name:'spark-sql-kafka-0-10_2.12', version: "2.4.0"

latest213DepTestImplementation group: 'org.apache.spark', name: "spark-core_2.13", version: '3.+'
latest213DepTestImplementation group: 'org.apache.spark', name: "spark-sql_2.13", version: '3.+'
latest212DepTestImplementation group: 'org.apache.spark', name: "spark-sql_2.13", version: "3.+"
latest212DepTestImplementation group: 'org.apache.spark', name:'spark-sql-kafka-0-10_2.13', version: "3.+"
}
@@ -1,38 +1,17 @@
import datadog.trace.agent.test.AgentTestRunner
import datadog.trace.bootstrap.instrumentation.api.Tags
import org.apache.kafka.clients.producer.ProducerRecord
import org.apache.spark.api.java.function.VoidFunction2
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.RowFactory
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger
import org.apache.spark.sql.types.StructType
import org.junit.ClassRule
import org.springframework.kafka.core.DefaultKafkaProducerFactory
import org.springframework.kafka.test.EmbeddedKafkaBroker
import org.springframework.kafka.test.rule.EmbeddedKafkaRule
import org.springframework.kafka.test.utils.KafkaTestUtils
import spock.lang.Shared


class SparkExecutorTest extends AgentTestRunner {
static final SOURCE_TOPIC = "source"
static final SINK_TOPIC = "sink"

@Shared
@ClassRule
EmbeddedKafkaRule kafkaRule = new EmbeddedKafkaRule(1, false, 1, SOURCE_TOPIC, SINK_TOPIC)
EmbeddedKafkaBroker embeddedKafka = kafkaRule.embeddedKafka

@Override
void configurePreAgent() {
super.configurePreAgent()
injectSysConfig("dd.integration.spark-executor.enabled", "true")
injectSysConfig("dd.integration.spark.enabled", "true")
injectSysConfig("dd.integration.kafka.enabled", "true")
injectSysConfig("dd.data.streams.enabled", "true")
injectSysConfig("dd.trace.debug", "true")
}

private Dataset<Row> generateSampleDataframe(SparkSession spark) {
@@ -44,57 +23,6 @@ class SparkExecutorTest extends AgentTestRunner {
spark.createDataFrame(rows, structType)
}

def "test dsm service name override"() {
setup:
def sparkSession = SparkSession.builder()
.config("spark.master", "local[2]")
.config("spark.driver.bindAddress", "localhost")
// .config("spark.sql.shuffle.partitions", "2")
.appName("test-app")
.getOrCreate()

def producerProps = KafkaTestUtils.producerProps(embeddedKafka.getBrokersAsString())
def producer = new DefaultKafkaProducerFactory<Integer, String>(producerProps).createProducer()

when:
for (int i = 0; i < 100; i++) {
producer.send(new ProducerRecord<>(SOURCE_TOPIC, i, i.toString()))
}
producer.flush()

def df = sparkSession
.readStream()
.format("kafka")
.option("kafka.bootstrap.servers", embeddedKafka.getBrokersAsString())
.option("startingOffsets", "earliest")
.option("failOnDataLoss", "false")
.option("subscribe", SOURCE_TOPIC)
.load()

def query = df
.selectExpr("CAST(key AS STRING) as key", "CAST(value AS STRING) as value")
.writeStream()
.format("kafka")
.option("kafka.bootstrap.servers", embeddedKafka.getBrokersAsString())
.option("checkpointLocation", "/tmp/" + System.currentTimeMillis().toString())
.option("topic", SINK_TOPIC)
.trigger(Trigger.Once())
.foreachBatch(new VoidFunction2<Dataset<Row>, Long>() {
@Override
void call(Dataset<Row> rowDataset, Long aLong) throws Exception {
rowDataset.show()
rowDataset.write()
}
})
.start()

query.processAllAvailable()

then:
query.stop()
producer.close()
}

def "generate spark task run spans"() {
setup:
def sparkSession = SparkSession.builder()
@@ -1,33 +1,15 @@
package datadog.trace.instrumentation.spark;

import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
import datadog.trace.bootstrap.instrumentation.api.UTF8BytesString;
import datadog.trace.bootstrap.instrumentation.decorator.BaseDecorator;
import datadog.trace.util.MethodHandles;
import java.lang.invoke.MethodHandle;
import java.util.Properties;
import org.apache.spark.executor.Executor;
import org.apache.spark.executor.TaskMetrics;

public class SparkExecutorDecorator extends BaseDecorator {
public static final CharSequence SPARK_TASK = UTF8BytesString.create("spark.task");
public static final CharSequence SPARK = UTF8BytesString.create("spark");
public static SparkExecutorDecorator DECORATE = new SparkExecutorDecorator();
private final String propSparkAppName = "spark.app.name";
private static final String TASK_DESCRIPTION_CLASSNAME =
"org.apache.spark.scheduler.TaskDescription";
private static final MethodHandle propertiesField_mh = getFieldGetter();

private static MethodHandle getFieldGetter() {
try {
return new MethodHandles(Executor.class.getClassLoader())
.privateFieldGetter(TASK_DESCRIPTION_CLASSNAME, "properties");
} catch (Throwable ignored) {
// should be already logged
}
return null;
}

@Override
protected String[] instrumentationNames() {
@@ -44,29 +26,12 @@ protected CharSequence component() {
return SPARK;
}

public void onTaskStart(AgentSpan span, Executor.TaskRunner taskRunner, Object taskDescription) {
public void onTaskStart(AgentSpan span, Executor.TaskRunner taskRunner) {
span.setTag("task_id", taskRunner.taskId());
span.setTag("task_thread_name", taskRunner.threadName());

if (taskDescription != null && propertiesField_mh != null) {
try {
Properties props = (Properties) propertiesField_mh.invoke(taskDescription);
if (props != null) {
String appName = props.getProperty(propSparkAppName);
if (appName != null) {
AgentTracer.get()
.getDataStreamsMonitoring()
.setThreadServiceName(taskRunner.getThreadId(), appName);
}
}
} catch (Throwable ignored) {
}
}
}

public void onTaskEnd(AgentSpan span, Executor.TaskRunner taskRunner) {
AgentTracer.get().getDataStreamsMonitoring().clearThreadServiceName(taskRunner.getThreadId());

// task is set by spark in run() by deserializing the task binary coming from the driver
if (taskRunner.task() == null) {
return;
Expand All @@ -85,7 +50,7 @@ public void onTaskEnd(AgentSpan span, Executor.TaskRunner taskRunner) {
span.setTag("app_attempt_id", taskRunner.task().appAttemptId().get());
}
span.setTag(
"application_name", taskRunner.task().localProperties().getProperty(propSparkAppName));
"application_name", taskRunner.task().localProperties().getProperty("spark.app.name"));

TaskMetrics metrics = taskRunner.task().metrics();
span.setMetric("spark.executor_deserialize_time", metrics.executorDeserializeTime());
@@ -52,13 +52,11 @@ public void methodAdvice(MethodTransformer transformer) {

public static final class RunAdvice {
@Advice.OnMethodEnter(suppress = Throwable.class)
public static AgentScope enter(
@Advice.FieldValue("taskDescription") final Object taskDescription,
@Advice.This Executor.TaskRunner taskRunner) {
public static AgentScope enter(@Advice.This Executor.TaskRunner taskRunner) {
final AgentSpan span = startSpan("spark-executor", SPARK_TASK);

DECORATE.afterStart(span);
DECORATE.onTaskStart(span, taskRunner, taskDescription);
DECORATE.onTaskStart(span, taskRunner);

return activateSpan(span);
}
@@ -75,8 +75,7 @@ public class DefaultDataStreamsMonitoring implements DataStreamsMonitoring, Even
private volatile boolean agentSupportsDataStreams = false;
private volatile boolean configSupportsDataStreams = false;
private final ConcurrentHashMap<String, SchemaSampler> schemaSamplers;
private static final ConcurrentHashMap<Long, String> threadServiceNames =
new ConcurrentHashMap<>();
private static final ThreadLocal<String> serviceNameOverride = new ThreadLocal<>();

public DefaultDataStreamsMonitoring(
Config config,
@@ -188,29 +187,28 @@ public void setProduceCheckpoint(String type, String target) {
}

@Override
public void setThreadServiceName(Long threadId, String serviceName) {
// setting service name to null == removing the value
public void setThreadServiceName(String serviceName) {
if (serviceName == null) {
clearThreadServiceName(threadId);
clearThreadServiceName();
return;
}

threadServiceNames.put(threadId, serviceName);
serviceNameOverride.set(serviceName);
}

@Override
public void clearThreadServiceName(Long threadId) {
threadServiceNames.remove(threadId);
public void clearThreadServiceName() {
serviceNameOverride.remove();
}

private static String getThreadServiceNameOverride() {
return threadServiceNames.getOrDefault(Thread.currentThread().getId(), null);
private static String getThreadServiceName() {
return serviceNameOverride.get();
}

@Override
public PathwayContext newPathwayContext() {
if (configSupportsDataStreams) {
return new DefaultPathwayContext(timeSource, hashOfKnownTags, getThreadServiceNameOverride());
return new DefaultPathwayContext(timeSource, hashOfKnownTags, getThreadServiceName());
} else {
return AgentTracer.NoopPathwayContext.INSTANCE;
}
@@ -219,7 +217,7 @@ public PathwayContext newPathwayContext() {
@Override
public HttpCodec.Extractor extractor(HttpCodec.Extractor delegate) {
return new DataStreamContextExtractor(
delegate, timeSource, traceConfigSupplier, hashOfKnownTags, getThreadServiceNameOverride());
delegate, timeSource, traceConfigSupplier, hashOfKnownTags, getThreadServiceName());
}

@Override
@@ -236,7 +234,7 @@ public void mergePathwayContextIntoSpan(AgentSpan span, DataStreamsContextCarrie
DataStreamsContextCarrierAdapter.INSTANCE,
this.timeSource,
this.hashOfKnownTags,
getThreadServiceNameOverride());
getThreadServiceName());
((DDSpan) span).context().mergePathwayContext(pathwayContext);
}
}
@@ -250,8 +248,7 @@ public void trackBacklog(LinkedHashMap<String, String> sortedTags, long value) {
}
tags.add(tag);
}
inbox.offer(
new Backlog(tags, value, timeSource.getCurrentTimeNanos(), getThreadServiceNameOverride()));
inbox.offer(new Backlog(tags, value, timeSource.getCurrentTimeNanos(), getThreadServiceName()));
}

@Override
@@ -79,14 +79,14 @@ class DataStreamsWritingTest extends DDCoreSpecification {
when:
def dataStreams = new DefaultDataStreamsMonitoring(fakeConfig, sharedCommObjects, timeSource, { traceConfig })
dataStreams.start()
dataStreams.setThreadServiceName(Thread.currentThread().getId(), serviceNameOverride)
dataStreams.setThreadServiceName(serviceNameOverride)
dataStreams.add(new StatsPoint([], 9, 0, 10, timeSource.currentTimeNanos, 0, 0, 0, serviceNameOverride))
dataStreams.trackBacklog(new LinkedHashMap<>(["partition": "1", "topic": "testTopic", "type": "kafka_produce"]), 130)
timeSource.advance(DEFAULT_BUCKET_DURATION_NANOS)
// force flush
dataStreams.report()
dataStreams.close()
dataStreams.clearThreadServiceName(Thread.currentThread().getId())
dataStreams.clearThreadServiceName()
then:
conditions.eventually {
assert requestBodies.size() == 1
@@ -45,17 +45,12 @@ void setCheckpoint(

/**
* setServiceNameOverride is used to override the service name for all DataStreams payloads produced
* within given thread
* within Thread.currentThread()
*
* @param threadId thread Id
* @param serviceName new service name to use for DSM checkpoints.
*/
void setThreadServiceName(Long threadId, String serviceName);
void setThreadServiceName(String serviceName);

/**
* clearThreadServiceName clears up threadId -> Service name mapping
*
* @param threadId thread Id
*/
void clearThreadServiceName(Long threadId);
/** clearThreadServiceName clears up service name override for Thread.currentThread() */
void clearThreadServiceName();
}
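Purely as a usage illustration, and not part of this commit, a caller could bracket per-thread work with the two interface methods above. The helper below is hypothetical: its name, parameters, and the omitted class and import scaffolding are assumptions, while DataStreamsMonitoring, setThreadServiceName, and clearThreadServiceName come from the diff.

void runWithServiceNameOverride(DataStreamsMonitoring dsm, String serviceName, Runnable work) {
  // Hypothetical helper: install the override for the current thread, run the work,
  // and always clear it afterwards so a pooled thread does not leak the name to later tasks.
  dsm.setThreadServiceName(serviceName);
  try {
    work.run();
  } finally {
    dsm.clearThreadServiceName();
  }
}
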
@@ -1135,10 +1135,10 @@ public Schema getSchema(String schemaName, SchemaIterator iterator) {
public void setProduceCheckpoint(String type, String target) {}

@Override
public void setThreadServiceName(Long threadId, String serviceName) {}
public void setThreadServiceName(String serviceName) {}

@Override
public void clearThreadServiceName(Long threadId) {}
public void clearThreadServiceName() {}

@Override
public void setConsumeCheckpoint(