use crate::telemetry::Telemetry;
use crate::trace;
use std::any::TypeId;
use std::collections::HashMap;
use std::time::SystemTime;
use tracing::span::{Attributes, Id, Record};
use tracing::{Event, Subscriber};
use tracing_subscriber::{layer::Context, registry, Layer};
#[cfg(feature = "use_parking_lot")]
use parking_lot::RwLock;
#[cfg(not(feature = "use_parking_lot"))]
use std::sync::RwLock;
/// A `tracing_subscriber` [`Layer`] that reports spans and events belonging to a
/// distributed trace to a [`Telemetry`] backend.
///
/// A span or event is only reported when it, or one of its ancestors, has a trace
/// context resolvable through `trace_ctx_registry`.
pub struct TelemetryLayer<T, SpanId, TraceId> {
    /// Service name attached to every reported span and event.
    service_name: &'static str,
    /// Backend that creates field visitors and receives reported spans and events.
    pub(crate) telemetry: T,
    /// Trace roots registered for local spans, plus the local-to-backend span id conversion.
    pub(crate) trace_ctx_registry: TraceCtxRegistry<SpanId, TraceId>,
}
/// Distributed-trace context resolved for a span: the trace it belongs to and, for a
/// registered trace root, the remote parent span it continues from.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct TraceCtx<SpanId, TraceId> {
    /// Remote parent recorded when a trace root was registered; `None` if none was given,
    /// and always `None` in contexts derived for descendants of a root.
    pub(crate) parent_span: Option<SpanId>,
    /// Identifier of the trace the span belongs to.
    pub(crate) trace_id: TraceId,
}
/// Bookkeeping for distributed-trace roots.
///
/// Holds the [`TraceCtx`] of every local span explicitly registered as a trace root, plus
/// the function that converts local span [`Id`]s into the backend's `SpanId` type.
pub(crate) struct TraceCtxRegistry<SpanId, TraceId> {
    /// Registered trace roots, keyed by local span id.
    registry: RwLock<HashMap<Id, TraceCtx<SpanId, TraceId>>>,
    /// Converts a local span id into the backend's span id type.
    promote_span_id: Box<dyn Fn(Id) -> SpanId + Send + Sync + 'static>,
}
impl<SpanId, TraceId> TraceCtxRegistry<SpanId, TraceId>
where
    SpanId: 'static + Clone + Send + Sync,
    TraceId: 'static + Clone + Send + Sync,
{
    /// Converts a local span [`Id`] into the backend's `SpanId` using the function supplied
    /// to [`TraceCtxRegistry::new`].
    pub(crate) fn promote_span_id(&self, id: Id) -> SpanId {
        (self.promote_span_id)(id)
    }

    /// Registers the local span `id` as a root of trace `trace_id`, optionally continuing
    /// from `remote_parent_span`. Registering the same span again replaces its context.
    ///
    /// # Panics
    ///
    /// Without the `use_parking_lot` feature, panics if the registry lock is poisoned.
    pub(crate) fn record_trace_ctx(
        &self,
        trace_id: TraceId,
        remote_parent_span: Option<SpanId>,
        id: Id,
    ) {
        let trace_ctx = TraceCtx {
            trace_id,
            parent_span: remote_parent_span,
        };
        #[cfg(not(feature = "use_parking_lot"))]
        let mut trace_ctx_registry = self.registry.write().expect("write lock!");
        #[cfg(feature = "use_parking_lot")]
        let mut trace_ctx_registry = self.registry.write();
        trace_ctx_registry.insert(id, trace_ctx);
    }

    /// Resolves the trace context of the first span yielded by `iter`.
    ///
    /// `iter` is expected to yield a span followed by its ancestors, innermost first. The
    /// walk stops at the first span that either has a context cached in its extensions or
    /// is registered as a trace root; every span visited before it gets the resolved
    /// context cached so later lookups stop early. Returns `None` if no span in the chain
    /// belongs to a trace.
    ///
    /// Locking: the registry lock is only held while copying one entry out, and at most
    /// one span's extensions are locked at any time. The registry lock is therefore never
    /// held while waiting on another lock, which rules out lock-order deadlocks with
    /// concurrent `eval_ctx` calls and `record_trace_ctx` writers.
    pub(crate) fn eval_ctx<
        'a,
        X: 'a + registry::LookupSpan<'a>,
        I: std::iter::Iterator<Item = registry::SpanRef<'a, X>>,
    >(
        &self,
        iter: I,
    ) -> Option<TraceCtx<SpanId, TraceId>> {
        // Spans visited so far whose context is not yet known, leaf first.
        let mut path = Vec::new();
        for span_ref in iter {
            let mut extensions = span_ref.extensions_mut();
            let resolved = match extensions.get_mut::<LazyTraceCtx<SpanId, TraceId>>() {
                Some(LazyTraceCtx(cached)) => Some(cached.clone()),
                None => {
                    let registered = self.lookup_registered(&span_ref.id());
                    if let Some(root) = &registered {
                        extensions.insert(LazyTraceCtx(root.clone()));
                    }
                    registered
                }
            };
            // Release this span's extensions before any other span's are locked.
            drop(extensions);
            match resolved {
                Some(found) => return Some(Self::cache_along_path(found, path)),
                None => path.push(span_ref),
            }
        }
        None
    }

    /// Returns a copy of the context registered for `id`, if any.
    ///
    /// The registry lock is released before this returns.
    ///
    /// # Panics
    ///
    /// Without the `use_parking_lot` feature, panics if the registry lock is poisoned.
    fn lookup_registered(&self, id: &Id) -> Option<TraceCtx<SpanId, TraceId>> {
        #[cfg(not(feature = "use_parking_lot"))]
        let roots = self.registry.read().unwrap();
        #[cfg(feature = "use_parking_lot")]
        let roots = self.registry.read();
        roots.get(id).cloned()
    }

    /// Finishes an `eval_ctx` walk.
    ///
    /// `found` is the context of the span the walk stopped at. `path` holds the spans below
    /// it, leaf first, that had no context. If `path` is empty, the leaf itself had the
    /// context and it is returned unchanged. Otherwise each span in `path` is tagged with
    /// the trace id and no remote parent (their parents are local spans), and that derived
    /// context is returned for the leaf.
    fn cache_along_path<'a, X: 'a + registry::LookupSpan<'a>>(
        found: TraceCtx<SpanId, TraceId>,
        path: Vec<registry::SpanRef<'a, X>>,
    ) -> TraceCtx<SpanId, TraceId> {
        if path.is_empty() {
            return found;
        }
        let derived = TraceCtx {
            trace_id: found.trace_id,
            parent_span: None,
        };
        for span_ref in path {
            // `replace` rather than `insert`: another thread may have cached a context for
            // this span since it was visited, and `insert` would panic on a duplicate.
            span_ref
                .extensions_mut()
                .replace(LazyTraceCtx(derived.clone()));
        }
        derived
    }

    /// Creates an empty registry that promotes local span ids with `f`.
    pub(crate) fn new<F: 'static + Send + Sync + Fn(Id) -> SpanId>(f: F) -> Self {
        let registry = RwLock::new(HashMap::new());
        let promote_span_id = Box::new(f);
        TraceCtxRegistry {
            registry,
            promote_span_id,
        }
    }
}
impl<T, SpanId, TraceId> TelemetryLayer<T, SpanId, TraceId>
where
    SpanId: 'static + Clone + Send + Sync,
    TraceId: 'static + Clone + Send + Sync,
{
    /// Builds a layer that reports to `telemetry`, tagging everything with `service_name`.
    ///
    /// `promote_span_id` converts the subscriber's local span [`Id`]s into the backend's
    /// `SpanId` type whenever span or parent ids are reported.
    pub fn new<F: 'static + Send + Sync + Fn(Id) -> SpanId>(
        service_name: &'static str,
        telemetry: T,
        promote_span_id: F,
    ) -> Self {
        Self {
            service_name,
            telemetry,
            trace_ctx_registry: TraceCtxRegistry::new(promote_span_id),
        }
    }
}
impl<S, TraceId, SpanId, V, T> Layer<S> for TelemetryLayer<T, SpanId, TraceId>
where
    S: Subscriber + for<'a> registry::LookupSpan<'a>,
    TraceId: 'static + Clone + Eq + Send + Sync,
    SpanId: 'static + Clone + Eq + Send + Sync,
    V: 'static + tracing::field::Visit + Send + Sync,
    T: 'static + Telemetry<Visitor = V, TraceId = TraceId, SpanId = SpanId>,
{
    /// Stores the span's creation time and a visitor holding its initial field values in
    /// the span's extensions; `on_close` consumes both.
    fn on_new_span(&self, attrs: &Attributes, id: &Id, ctx: Context<S>) {
        let span = ctx.span(id).expect("span data not found during new_span");
        let mut extensions_mut = span.extensions_mut();
        extensions_mut.insert(SpanInitAt::new());
        let mut visitor: V = self.telemetry.mk_visitor();
        attrs.record(&mut visitor);
        extensions_mut.insert::<V>(visitor);
    }

    /// Records field values set after creation into the span's stored visitor.
    fn on_record(&self, id: &Id, values: &Record, ctx: Context<S>) {
        let span = ctx.span(id).expect("span data not found during on_record");
        let mut extensions_mut = span.extensions_mut();
        let visitor: &mut V = extensions_mut
            .get_mut()
            .expect("fields extension not found during on_record");
        values.record(visitor);
    }

    /// Reports `event` if its parent span belongs to a trace.
    ///
    /// The parent is the explicit parent, or the current span for contextual events. Root
    /// events and events outside any span are not reported.
    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
        let parent_id = if let Some(parent_id) = event.parent() {
            Some(parent_id.clone())
        } else if event.is_root() {
            None
        } else {
            ctx.current_span().id().cloned()
        };
        let parent_id = match parent_id {
            Some(parent_id) => parent_id,
            None => return,
        };
        let initialized_at = SystemTime::now();
        let mut visitor = self.telemetry.mk_visitor();
        event.record(&mut visitor);
        // The parent span followed by each of its ancestors, innermost first.
        let iter = itertools::unfold(Some(parent_id.clone()), |st| match st {
            Some(target_id) => {
                let res = ctx
                    .span(target_id)
                    .expect("span data not found during eval_ctx");
                *st = res.parent().map(|x| x.id());
                Some(res)
            }
            None => None,
        });
        if let Some(parent_trace_ctx) = self.trace_ctx_registry.eval_ctx(iter) {
            let event = trace::Event {
                trace_id: parent_trace_ctx.trace_id,
                parent_id: Some(self.trace_ctx_registry.promote_span_id(parent_id)),
                initialized_at,
                meta: event.metadata(),
                service_name: self.service_name,
                values: visitor,
            };
            self.telemetry.report_event(event);
        }
    }

    /// Reports the closing span if it belongs to a trace. Also removes any trace-root entry
    /// registered for the span's `id`.
    fn on_close(&self, id: Id, ctx: Context<'_, S>) {
        let span = ctx.span(&id).expect("span data not found during on_close");
        // The closing span followed by each of its ancestors, innermost first.
        let iter = itertools::unfold(Some(id.clone()), |st| match st {
            Some(target_id) => {
                let res = ctx
                    .span(target_id)
                    .expect("span data not found during eval_ctx");
                *st = res.parent().map(|x| x.id());
                Some(res)
            }
            None => None,
        });
        let trace_ctx = self.trace_ctx_registry.eval_ctx(iter);

        // Nothing else removes registry entries. Without this, every registered root stays
        // in the map forever. The subscriber may also reuse a closed span's `Id`, which
        // would make an unrelated later span look like this trace root. `eval_ctx` has
        // already used the entry above. The read-lock check keeps closes of ordinary,
        // unregistered spans off the write lock.
        #[cfg(not(feature = "use_parking_lot"))]
        let registered = self
            .trace_ctx_registry
            .registry
            .read()
            .unwrap()
            .contains_key(&id);
        #[cfg(feature = "use_parking_lot")]
        let registered = self.trace_ctx_registry.registry.read().contains_key(&id);
        if registered {
            #[cfg(not(feature = "use_parking_lot"))]
            let mut roots = self
                .trace_ctx_registry
                .registry
                .write()
                .expect("write lock!");
            #[cfg(feature = "use_parking_lot")]
            let mut roots = self.trace_ctx_registry.registry.write();
            roots.remove(&id);
        }

        if let Some(trace_ctx) = trace_ctx {
            let mut extensions_mut = span.extensions_mut();
            let visitor: V = extensions_mut
                .remove()
                .expect("should be present on all spans");
            let SpanInitAt(initialized_at) = extensions_mut
                .remove()
                .expect("should be present on all spans");
            let completed_at = SystemTime::now();
            // A registered root reports its remote parent; otherwise the local parent span.
            let parent_id = match trace_ctx.parent_span {
                None => span
                    .parent()
                    .map(|parent_ref| self.trace_ctx_registry.promote_span_id(parent_ref.id())),
                Some(parent_span) => Some(parent_span),
            };
            let span = trace::Span {
                id: self.trace_ctx_registry.promote_span_id(id),
                meta: span.metadata(),
                parent_id,
                initialized_at,
                trace_id: trace_ctx.trace_id,
                completed_at,
                service_name: self.service_name,
                values: visitor,
            };
            self.telemetry.report_span(span);
        }
    }

    /// Lets callers downcast to the layer itself or to its `TraceCtxRegistry`.
    unsafe fn downcast_raw(&self, id: TypeId) -> Option<*const ()> {
        match () {
            _ if id == TypeId::of::<Self>() => Some(self as *const Self as *const ()),
            _ if id == TypeId::of::<TraceCtxRegistry<SpanId, TraceId>>() => Some(
                &self.trace_ctx_registry as *const TraceCtxRegistry<SpanId, TraceId> as *const (),
            ),
            _ => None,
        }
    }
}
/// Span extension caching the [`TraceCtx`] already resolved for a span, so later
/// resolutions can stop at this span instead of walking further up its ancestors.
struct LazyTraceCtx<SpanId, TraceId>(TraceCtx<SpanId, TraceId>);
/// Span extension holding the wall-clock time at which the span was created.
struct SpanInitAt(SystemTime);

impl SpanInitAt {
    /// Captures the current wall-clock time.
    fn new() -> Self {
        SpanInitAt(SystemTime::now())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::telemetry::test::{SpanId, TestTelemetry, TraceId};
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::time::Duration;
    use tokio::runtime::Runtime;
    use tracing::instrument;
    use tracing_subscriber::layer::Layer;

    /// Trace id the scenarios register on their root span.
    fn explicit_trace_id() -> TraceId {
        135
    }

    /// Remote parent span id the scenarios register on their root span.
    fn explicit_parent_span_id() -> SpanId {
        Id::from_u64(246)
    }

    /// Synchronous scenario: `f` registers its span as the trace root, then calls `g`
    /// three times. Each `g` emits one event and checks the current trace id.
    #[test]
    fn test_instrument() {
        with_test_scenario_runner(|| {
            #[instrument]
            fn f(ns: Vec<u64>) {
                trace::register_dist_tracing_root(
                    explicit_trace_id(),
                    Some(explicit_parent_span_id()),
                )
                .unwrap();
                for n in ns {
                    g(format!("{}", n));
                }
            }
            #[instrument]
            fn g(_s: String) {
                let use_of_reserved_word = "duration-value";
                tracing::event!(
                    tracing::Level::INFO,
                    duration_ms = use_of_reserved_word,
                    foo = "bar"
                );
                assert_eq!(
                    trace::current_dist_trace_ctx::<SpanId, TraceId>()
                        .map(|x| x.0)
                        .unwrap(),
                    explicit_trace_id(),
                );
            }
            f(vec![1, 2, 3]);
        });
    }

    /// Same scenario as `test_instrument`, with `async` functions driven by a tokio runtime.
    #[test]
    fn test_async_instrument() {
        with_test_scenario_runner(|| {
            #[instrument]
            async fn f(ns: Vec<u64>) {
                trace::register_dist_tracing_root(
                    explicit_trace_id(),
                    Some(explicit_parent_span_id()),
                )
                .unwrap();
                for n in ns {
                    g(format!("{}", n)).await;
                }
            }
            #[instrument]
            async fn g(s: String) {
                tokio::time::delay_for(Duration::from_millis(100)).await;
                let use_of_reserved_word = "duration-value";
                tracing::event!(
                    tracing::Level::INFO,
                    duration_ms = use_of_reserved_word,
                    foo = "bar"
                );
                assert_eq!(
                    trace::current_dist_trace_ctx::<SpanId, TraceId>()
                        .map(|x| x.0)
                        .unwrap(),
                    explicit_trace_id(),
                );
            }
            let mut rt = Runtime::new().unwrap();
            rt.block_on(f(vec![1, 2, 3]));
        });
    }

    /// Runs `f` under a subscriber with a `TelemetryLayer` backed by `TestTelemetry`, then
    /// checks what was captured.
    ///
    /// Expected: one root span carrying the explicit remote parent, three child spans
    /// parented to it, and one event per child span, all in the explicit trace.
    fn with_test_scenario_runner<F>(f: F)
    where
        F: Fn(),
    {
        let spans = Arc::new(Mutex::new(Vec::new()));
        let events = Arc::new(Mutex::new(Vec::new()));
        let cap: TestTelemetry = TestTelemetry::new(spans.clone(), events.clone());
        let layer = TelemetryLayer::new("test_svc_name", cap, |x| x);
        let subscriber = layer.with_subscriber(registry::Registry::default());
        tracing::subscriber::with_default(subscriber, f);
        let spans = spans.lock().unwrap();
        let events = events.lock().unwrap();
        // Spans are reported as they close, so the three children come before the root.
        assert_eq!(spans.len(), 4, "expected one root span and three child spans");
        let root_span = &spans[3];
        let child_spans = &spans[0..3];
        // Without this, the `zip` below would silently check nothing if events went missing.
        assert_eq!(
            events.len(),
            child_spans.len(),
            "expected exactly one event per child span"
        );
        let expected_trace_id = explicit_trace_id();
        assert_eq!(root_span.parent_id, Some(explicit_parent_span_id()));
        assert_eq!(root_span.trace_id, expected_trace_id);
        for (span, event) in child_spans.iter().zip(events.iter()) {
            assert_eq!(span.parent_id, Some(root_span.id.clone()));
            assert_eq!(event.parent_id, Some(span.id.clone()));
            assert_eq!(span.trace_id, explicit_trace_id());
            assert_eq!(event.trace_id, explicit_trace_id());
        }
    }
}