use opentelemetry::{
trace as otel,
trace::{
noop, SamplingDecision, SamplingResult, SpanBuilder, SpanContext, SpanId, SpanKind,
TraceContextExt, TraceFlags, TraceId, TraceState,
},
Context as OtelContext,
};
use opentelemetry_sdk::trace::{Tracer as SdkTracer, TracerProvider as SdkTracerProvider};
/// A tracer that can make a sampling decision for a span *before* the span is
/// started, and that can generate trace and span identifiers on request.
pub trait PreSampledTracer {
    /// Returns a context describing the span that `data` is about to produce.
    ///
    /// `data` is taken mutably so an implementation may record its sampling
    /// decision on the pending span builder for later reuse. Some
    /// implementations instead return the parent context unchanged.
    fn sampled_context(&self, data: &mut crate::OtelData) -> OtelContext;
    /// Generates a new trace id. The result may be [`otel::TraceId::INVALID`]
    /// when the implementation cannot generate ids.
    fn new_trace_id(&self) -> otel::TraceId;
    /// Generates a new span id. The result may be [`otel::SpanId::INVALID`]
    /// when the implementation cannot generate ids.
    fn new_span_id(&self) -> otel::SpanId;
}
impl PreSampledTracer for noop::NoopTracer {
fn sampled_context(&self, data: &mut crate::OtelData) -> OtelContext {
data.parent_cx.clone()
}
fn new_trace_id(&self) -> otel::TraceId {
otel::TraceId::INVALID
}
fn new_span_id(&self) -> otel::SpanId {
otel::SpanId::INVALID
}
}
/// Pre-sampling for the SDK tracer, driven by its provider's sampler and id generator.
impl PreSampledTracer for SdkTracer {
    /// Builds a child of `data.parent_cx` whose remote span context carries the
    /// sampling decision for the pending span.
    ///
    /// If `data.builder` already holds a sampling result, that result is reused.
    /// Otherwise the provider's sampler is consulted and its result is stored
    /// on the builder. When the tracer's provider is unavailable, an empty
    /// context is returned.
    fn sampled_context(&self, data: &mut crate::OtelData) -> OtelContext {
        let provider = match self.provider() {
            Some(provider) => provider,
            // No provider to consult: fall back to an empty context.
            None => return OtelContext::new(),
        };

        let parent_cx = &data.parent_cx;
        let builder = &mut data.builder;

        let (trace_id, parent_trace_flags) = current_trace_state(builder, parent_cx, &provider);

        // Ask the sampler only when no earlier decision was recorded, and
        // remember its answer on the builder so it is not recomputed.
        if builder.sampling_result.is_none() {
            let decision = provider.config().sampler.should_sample(
                Some(parent_cx),
                trace_id,
                &builder.name,
                builder.span_kind.as_ref().unwrap_or(&SpanKind::Internal),
                builder.attributes.as_deref().unwrap_or(&[]),
                builder.links.as_deref().unwrap_or(&[]),
            );
            builder.sampling_result = Some(decision);
        }

        // A `Drop` decision yields `None`, which becomes default flags/state.
        let (flags, trace_state) = builder
            .sampling_result
            .as_ref()
            .and_then(|result| process_sampling_result(result, parent_trace_flags))
            .unwrap_or_default();

        let span_context = SpanContext::new(
            trace_id,
            builder.span_id.unwrap_or(SpanId::INVALID),
            flags,
            false,
            trace_state,
        );
        parent_cx.with_remote_span_context(span_context)
    }

    /// Generates a trace id from the provider's id generator, or returns
    /// [`otel::TraceId::INVALID`] when the provider is unavailable.
    fn new_trace_id(&self) -> otel::TraceId {
        self.provider().map_or(otel::TraceId::INVALID, |provider| {
            provider.config().id_generator.new_trace_id()
        })
    }

    /// Generates a span id from the provider's id generator, or returns
    /// [`otel::SpanId::INVALID`] when the provider is unavailable.
    fn new_span_id(&self) -> otel::SpanId {
        self.provider().map_or(otel::SpanId::INVALID, |provider| {
            provider.config().id_generator.new_span_id()
        })
    }
}
fn current_trace_state(
builder: &SpanBuilder,
parent_cx: &OtelContext,
provider: &SdkTracerProvider,
) -> (TraceId, TraceFlags) {
if parent_cx.has_active_span() {
let span = parent_cx.span();
let sc = span.span_context();
(sc.trace_id(), sc.trace_flags())
} else {
(
builder
.trace_id
.unwrap_or_else(|| provider.config().id_generator.new_trace_id()),
Default::default(),
)
}
}
/// Converts a sampler's result into the flags and trace state for a span.
///
/// Returns `None` for [`SamplingDecision::Drop`]. For `RecordOnly`, the
/// SAMPLED bit is cleared from `trace_flags`; for `RecordAndSample` it is set.
/// In both non-drop cases the result's trace state is cloned into the output.
fn process_sampling_result(
    sampling_result: &SamplingResult,
    trace_flags: TraceFlags,
) -> Option<(TraceFlags, TraceState)> {
    let flags = match sampling_result.decision {
        SamplingDecision::Drop => return None,
        SamplingDecision::RecordOnly => trace_flags & !TraceFlags::SAMPLED,
        SamplingDecision::RecordAndSample => trace_flags | TraceFlags::SAMPLED,
    };
    Some((flags, sampling_result.trace_state.clone()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OtelData;
    use opentelemetry::trace::TracerProvider as _;
    use opentelemetry_sdk::trace::{config, Sampler, TracerProvider};

    /// A root span with no preassigned trace id must still get a valid one.
    #[test]
    fn assigns_default_trace_id_if_missing() {
        let provider = TracerProvider::default();
        let tracer = provider.tracer("test");
        let mut builder = SpanBuilder::from_name("empty".to_string());
        builder.span_id = Some(SpanId::from(1u64));
        builder.trace_id = None;
        let parent_cx = OtelContext::new();
        let cx = tracer.sampled_context(&mut OtelData { builder, parent_cx });
        let span = cx.span();
        let span_context = span.span_context();

        assert!(span_context.is_valid());
    }

    /// Cases: (name, sampler, parent context, pre-existing sampling result, expected is_sampled).
    #[rustfmt::skip]
    fn sampler_data() -> Vec<(&'static str, Sampler, OtelContext, Option<SamplingResult>, bool)> {
        vec![
            // No parent: the configured sampler decides.
            ("empty_parent_cx_always_on", Sampler::AlwaysOn, OtelContext::new(), None, true),
            ("empty_parent_cx_always_off", Sampler::AlwaysOff, OtelContext::new(), None, false),

            // Remote parent.
            ("remote_parent_cx_always_on", Sampler::AlwaysOn, OtelContext::new().with_remote_span_context(span_context(TraceFlags::SAMPLED, true)), None, true),
            ("remote_parent_cx_always_off", Sampler::AlwaysOff, OtelContext::new().with_remote_span_context(span_context(TraceFlags::SAMPLED, true)), None, false),
            ("sampled_remote_parent_cx_parent_based", Sampler::ParentBased(Box::new(Sampler::AlwaysOff)), OtelContext::new().with_remote_span_context(span_context(TraceFlags::SAMPLED, true)), None, true),
            ("unsampled_remote_parent_cx_parent_based", Sampler::ParentBased(Box::new(Sampler::AlwaysOn)), OtelContext::new().with_remote_span_context(span_context(TraceFlags::default(), true)), None, false),

            // An existing sampling result takes precedence over the sampler.
            ("previous_drop_result_always_on", Sampler::AlwaysOn, OtelContext::new(), Some(SamplingResult { decision: SamplingDecision::Drop, attributes: vec![], trace_state: Default::default() }), false),
            ("previous_record_and_sample_result_always_off", Sampler::AlwaysOff, OtelContext::new(), Some(SamplingResult { decision: SamplingDecision::RecordAndSample, attributes: vec![], trace_state: Default::default() }), true),

            // RecordOnly records but never samples, even when the parent was sampled.
            ("previous_record_only_result_always_on", Sampler::AlwaysOn, OtelContext::new(), Some(SamplingResult { decision: SamplingDecision::RecordOnly, attributes: vec![], trace_state: Default::default() }), false),
            ("previous_record_only_result_sampled_remote_parent", Sampler::AlwaysOn, OtelContext::new().with_remote_span_context(span_context(TraceFlags::SAMPLED, true)), Some(SamplingResult { decision: SamplingDecision::RecordOnly, attributes: vec![], trace_state: Default::default() }), false),
        ]
    }

    #[test]
    fn sampled_context() {
        for (name, sampler, parent_cx, previous_sampling_result, is_sampled) in sampler_data() {
            let provider = TracerProvider::builder()
                .with_config(config().with_sampler(sampler))
                .build();
            let tracer = provider.tracer("test");
            let mut builder = SpanBuilder::from_name("parent".to_string());
            builder.sampling_result = previous_sampling_result;
            let sampled = tracer.sampled_context(&mut OtelData { builder, parent_cx });

            assert_eq!(
                sampled.span().span_context().is_sampled(),
                is_sampled,
                "{}",
                name
            )
        }
    }

    /// Fixed-id span context used as a parent in the sampler cases.
    fn span_context(trace_flags: TraceFlags, is_remote: bool) -> SpanContext {
        SpanContext::new(
            TraceId::from(1u128),
            SpanId::from(1u64),
            trace_flags,
            is_remote,
            Default::default(),
        )
    }
}