1use crate::telemetry::Telemetry;
2use crate::trace;
3use std::any::TypeId;
4use std::collections::HashMap;
5use std::time::SystemTime;
6use tracing::span::{Attributes, Id, Record};
7use tracing::{Event, Subscriber};
8use tracing_subscriber::{layer::Context, registry, Layer};
9
10#[cfg(feature = "use_parking_lot")]
11use parking_lot::RwLock;
12#[cfg(not(feature = "use_parking_lot"))]
13use std::sync::RwLock;
14
15pub struct TelemetryLayer<Telemetry, SpanId, TraceId> {
18 service_name: &'static str,
19 pub(crate) telemetry: Telemetry,
20 pub(crate) trace_ctx_registry: TraceCtxRegistry<SpanId, TraceId>,
22}
23
/// Trace context resolved for a span: the trace it belongs to and, for a
/// registered local trace root, the remote parent span passed at registration.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct TraceCtx<SpanId, TraceId> {
    /// Remote parent recorded for a local trace root; `None` for spans that
    /// inherit the context from a local ancestor.
    pub(crate) parent_span: Option<SpanId>,
    /// Identifier of the trace this span belongs to.
    pub(crate) trace_id: TraceId,
}
29
30pub(crate) struct TraceCtxRegistry<SpanId, TraceId> {
32 registry: RwLock<HashMap<Id, TraceCtx<SpanId, TraceId>>>,
33 promote_span_id: Box<dyn 'static + Send + Sync + Fn(Id) -> SpanId>,
34}
35
36impl<SpanId, TraceId> TraceCtxRegistry<SpanId, TraceId>
37where
38 SpanId: 'static + Clone + Send + Sync,
39 TraceId: 'static + Clone + Send + Sync,
40{
41 pub(crate) fn promote_span_id(&self, id: Id) -> SpanId {
42 (self.promote_span_id)(id)
43 }
44
45 pub(crate) fn record_trace_ctx(
46 &self,
47 trace_id: TraceId,
48 remote_parent_span: Option<SpanId>,
49 id: Id,
50 ) {
51 let trace_ctx = TraceCtx {
52 trace_id,
53 parent_span: remote_parent_span,
54 };
55
56 #[cfg(not(feature = "use_parking_lot"))]
57 let mut trace_ctx_registry = self.registry.write().expect("write lock!");
58 #[cfg(feature = "use_parking_lot")]
59 let mut trace_ctx_registry = self.registry.write();
60
61 trace_ctx_registry.insert(id, trace_ctx); }
63
64 pub(crate) fn eval_ctx<
65 'a,
66 X: 'a + registry::LookupSpan<'a>,
67 I: std::iter::Iterator<Item = registry::SpanRef<'a, X>>,
68 >(
69 &self,
70 iter: I,
71 ) -> Option<TraceCtx<SpanId, TraceId>> {
72 let mut path = Vec::new();
73
74 for span_ref in iter {
75 let mut write_guard = span_ref.extensions_mut();
76 match write_guard.get_mut::<LazyTraceCtx<SpanId, TraceId>>() {
77 None => {
78 #[cfg(not(feature = "use_parking_lot"))]
79 let trace_ctx_registry = self.registry.read().unwrap();
80 #[cfg(feature = "use_parking_lot")]
81 let trace_ctx_registry = self.registry.read();
82
83 match trace_ctx_registry.get(&span_ref.id()) {
84 None => {
85 drop(write_guard);
86 path.push(span_ref);
87 }
88 Some(local_trace_root) => {
89 write_guard.insert(LazyTraceCtx(local_trace_root.clone()));
90
91 let res = if path.is_empty() {
92 local_trace_root.clone()
93 } else {
94 TraceCtx {
95 trace_id: local_trace_root.trace_id.clone(),
96 parent_span: None,
97 }
98 };
99
100 for span_ref in path.into_iter() {
101 let mut write_guard = span_ref.extensions_mut();
102 write_guard.replace::<LazyTraceCtx<SpanId, TraceId>>(LazyTraceCtx(
103 TraceCtx {
104 trace_id: local_trace_root.trace_id.clone(),
105 parent_span: None,
106 },
107 ));
108 }
109 return Some(res);
110 }
111 }
112 }
113 Some(LazyTraceCtx(already_evaluated)) => {
114 let res = if path.is_empty() {
115 already_evaluated.clone()
116 } else {
117 TraceCtx {
118 trace_id: already_evaluated.trace_id.clone(),
119 parent_span: None,
120 }
121 };
122
123 for span_ref in path.into_iter() {
124 let mut write_guard = span_ref.extensions_mut();
125 write_guard.replace::<LazyTraceCtx<SpanId, TraceId>>(LazyTraceCtx(
126 TraceCtx {
127 trace_id: already_evaluated.trace_id.clone(),
128 parent_span: None,
129 },
130 ));
131 }
132 return Some(res);
133 }
134 }
135 }
136
137 None
138 }
139
140 pub(crate) fn new<F: 'static + Send + Sync + Fn(Id) -> SpanId>(f: F) -> Self {
141 let registry = RwLock::new(HashMap::new());
142 let promote_span_id = Box::new(f);
143
144 TraceCtxRegistry {
145 registry,
146 promote_span_id,
147 }
148 }
149}
150
151impl<T, SpanId, TraceId> TelemetryLayer<T, SpanId, TraceId>
152where
153 SpanId: 'static + Clone + Send + Sync,
154 TraceId: 'static + Clone + Send + Sync,
155{
156 pub fn new<F: 'static + Send + Sync + Fn(Id) -> SpanId>(
160 service_name: &'static str,
161 telemetry: T,
162 promote_span_id: F,
163 ) -> Self {
164 let trace_ctx_registry = TraceCtxRegistry::new(promote_span_id);
165
166 TelemetryLayer {
167 service_name,
168 telemetry,
169 trace_ctx_registry,
170 }
171 }
172}
173
174impl<S, TraceId, SpanId, V, T> Layer<S> for TelemetryLayer<T, SpanId, TraceId>
175where
176 S: Subscriber + for<'a> registry::LookupSpan<'a>,
177 TraceId: 'static + Clone + Eq + Send + Sync,
178 SpanId: 'static + Clone + Eq + Send + Sync,
179 V: 'static + tracing::field::Visit + Send + Sync,
180 T: 'static + Telemetry<Visitor = V, TraceId = TraceId, SpanId = SpanId>,
181{
182 fn on_new_span(&self, attrs: &Attributes, id: &Id, ctx: Context<S>) {
183 let span = ctx.span(id).expect("span data not found during new_span");
184 let mut extensions_mut = span.extensions_mut();
185 extensions_mut.insert(SpanInitAt::new());
186
187 let mut visitor: V = self.telemetry.mk_visitor();
188 attrs.record(&mut visitor);
189 extensions_mut.insert::<V>(visitor);
190 }
191
192 fn on_record(&self, id: &Id, values: &Record, ctx: Context<S>) {
193 let span = ctx.span(id).expect("span data not found during on_record");
194 let mut extensions_mut = span.extensions_mut();
195 let visitor: &mut V = extensions_mut
196 .get_mut()
197 .expect("fields extension not found during on_record");
198 values.record(visitor);
199 }
200
201 fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
202 let parent_id = if let Some(parent_id) = event.parent() {
203 Some(parent_id.clone())
205 } else if event.is_root() {
206 None
208 } else {
209 ctx.current_span().id().cloned()
211 };
212
213 match parent_id {
214 None => {} Some(parent_id) => {
216 let initialized_at = SystemTime::now();
217
218 let mut visitor = self.telemetry.mk_visitor();
219 event.record(&mut visitor);
220
221 let iter = itertools::unfold(Some(parent_id.clone()), |st| match st {
223 Some(target_id) => {
224 let res = ctx
225 .span(target_id)
226 .expect("span data not found during eval_ctx");
227 *st = res.parent().map(|x| x.id());
228 Some(res)
229 }
230 None => None,
231 });
232
233 if let Some(parent_trace_ctx) = self.trace_ctx_registry.eval_ctx(iter) {
235 let event = trace::Event {
236 trace_id: parent_trace_ctx.trace_id,
237 parent_id: Some(self.trace_ctx_registry.promote_span_id(parent_id)),
238 initialized_at,
239 meta: event.metadata(),
240 service_name: self.service_name,
241 values: visitor,
242 };
243
244 self.telemetry.report_event(event);
245 }
246 }
247 }
248 }
249
250 fn on_close(&self, id: Id, ctx: Context<'_, S>) {
251 let span = ctx.span(&id).expect("span data not found during on_close");
252
253 let iter = itertools::unfold(Some(id.clone()), |st| match st {
255 Some(target_id) => {
256 let res = ctx
257 .span(target_id)
258 .expect("span data not found during eval_ctx");
259 *st = res.parent().map(|x| x.id());
260 Some(res)
261 }
262 None => None,
263 });
264
265 if let Some(trace_ctx) = self.trace_ctx_registry.eval_ctx(iter) {
267 let mut extensions_mut = span.extensions_mut();
268 let visitor: V = extensions_mut
269 .remove()
270 .expect("should be present on all spans");
271 let SpanInitAt(initialized_at) = extensions_mut
272 .remove()
273 .expect("should be present on all spans");
274
275 let completed_at = SystemTime::now();
276
277 let parent_id = match trace_ctx.parent_span {
278 None => span
279 .parent()
280 .map(|parent_ref| self.trace_ctx_registry.promote_span_id(parent_ref.id())),
281 Some(parent_span) => Some(parent_span),
282 };
283
284 let span = trace::Span {
285 id: self.trace_ctx_registry.promote_span_id(id),
286 meta: span.metadata(),
287 parent_id,
288 initialized_at,
289 trace_id: trace_ctx.trace_id,
290 completed_at,
291 service_name: self.service_name,
292 values: visitor,
293 };
294
295 self.telemetry.report_span(span);
296 };
297 }
298
299 unsafe fn downcast_raw(&self, id: TypeId) -> Option<*const ()> {
304 match () {
308 _ if id == TypeId::of::<Self>() => Some(self as *const Self as *const ()),
309 _ if id == TypeId::of::<TraceCtxRegistry<SpanId, TraceId>>() => Some(
310 &self.trace_ctx_registry as *const TraceCtxRegistry<SpanId, TraceId> as *const (),
311 ),
312 _ => None,
313 }
314 }
315}
316
317struct LazyTraceCtx<SpanId, TraceId>(TraceCtx<SpanId, TraceId>);
319
/// Wall-clock time at which a span was created, kept in the span's extensions.
struct SpanInitAt(SystemTime);

impl SpanInitAt {
    /// Captures the current system time.
    fn new() -> Self {
        SpanInitAt(SystemTime::now())
    }
}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::telemetry::test::{SpanId, TestTelemetry, TraceId};
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::time::Duration;
    use tokio::runtime::Runtime;
    use tracing::instrument;
    use tracing_subscriber::layer::Layer;

    fn explicit_trace_id() -> TraceId {
        135
    }

    fn explicit_parent_span_id() -> SpanId {
        Id::from_u64(246)
    }

    #[test]
    fn test_instrument() {
        with_test_scenario_runner(|| {
            #[instrument]
            fn f(ns: Vec<u64>) {
                trace::register_dist_tracing_root(
                    explicit_trace_id(),
                    Some(explicit_parent_span_id()),
                )
                .unwrap();
                for n in ns {
                    g(format!("{}", n));
                }
            }

            #[instrument]
            fn g(_s: String) {
                let use_of_reserved_word = "duration-value";
                tracing::event!(
                    tracing::Level::INFO,
                    duration_ms = use_of_reserved_word,
                    foo = "bar"
                );

                assert_eq!(
                    trace::current_dist_trace_ctx::<SpanId, TraceId>()
                        .map(|x| x.0)
                        .unwrap(),
                    explicit_trace_id(),
                );
            }

            f(vec![1, 2, 3]);
        });
    }

    #[test]
    fn test_async_instrument() {
        with_test_scenario_runner(|| {
            #[instrument]
            async fn f(ns: Vec<u64>) {
                trace::register_dist_tracing_root(
                    explicit_trace_id(),
                    Some(explicit_parent_span_id()),
                )
                .unwrap();
                for n in ns {
                    g(format!("{}", n)).await;
                }
            }

            #[instrument]
            async fn g(s: String) {
                tokio::time::delay_for(Duration::from_millis(100)).await;
                let use_of_reserved_word = "duration-value";
                tracing::event!(
                    tracing::Level::INFO,
                    duration_ms = use_of_reserved_word,
                    foo = "bar"
                );

                assert_eq!(
                    trace::current_dist_trace_ctx::<SpanId, TraceId>()
                        .map(|x| x.0)
                        .unwrap(),
                    explicit_trace_id(),
                );
            }

            let mut rt = Runtime::new().unwrap();
            rt.block_on(f(vec![1, 2, 3]));
        });
    }

    /// Runs `f` under a subscriber with a `TelemetryLayer` backed by
    /// `TestTelemetry`, then checks what was reported.
    ///
    /// `f` is expected to register the explicit trace root inside one outer span
    /// and to emit exactly one event from each of three child spans.
    fn with_test_scenario_runner<F>(f: F)
    where
        F: Fn(),
    {
        let spans = Arc::new(Mutex::new(Vec::new()));
        let events = Arc::new(Mutex::new(Vec::new()));
        let cap: TestTelemetry = TestTelemetry::new(spans.clone(), events.clone());
        let layer = TelemetryLayer::new("test_svc_name", cap, |x| x);

        let subscriber = layer.with_subscriber(registry::Registry::default());
        tracing::subscriber::with_default(subscriber, f);

        let spans = spans.lock().unwrap();
        let events = events.lock().unwrap();

        // Child spans close (and are reported) before the root span enclosing
        // them, so the root is reported last.
        assert_eq!(
            spans.len(),
            4,
            "expected 3 child spans followed by the root span"
        );
        // Checked explicitly: the `zip` below would pass vacuously if fewer
        // events than child spans had been reported.
        assert_eq!(events.len(), 3, "expected one event per child span");

        let root_span = &spans[3];
        let child_spans = &spans[0..3];

        let expected_trace_id = explicit_trace_id();

        assert_eq!(root_span.parent_id, Some(explicit_parent_span_id()));
        assert_eq!(root_span.trace_id, expected_trace_id);

        for (span, event) in child_spans.iter().zip(events.iter()) {
            assert_eq!(span.parent_id, Some(root_span.id.clone()));
            assert_eq!(event.parent_id, Some(span.id.clone()));
            assert_eq!(span.trace_id, explicit_trace_id());
            assert_eq!(event.trace_id, explicit_trace_id());
        }
    }
}
457}