quil_rs/program/scheduling/
schedule.rs

//! A [`Schedule`] represents a flattening of a [`ScheduledBasicBlock`]'s dependency graph into a
//! linear sequence of instructions, with each instruction assigned a start time and duration.
3
4use std::collections::HashMap;
5
6use itertools::Itertools;
7use petgraph::{
8    visit::{EdgeFiltered, Topo},
9    Direction,
10};
11
12use crate::{
13    instruction::{
14        AttributeValue, Capture, Delay, Instruction, Pulse, RawCapture, WaveformInvocation,
15    },
16    quil::Quil,
17    Program,
18};
19
20use super::{ExecutionDependency, ScheduledBasicBlock, ScheduledGraphNode};
21
/// A quantity of time measured in seconds, stored as an `f64`.
#[derive(Clone, Debug, Default, PartialEq, PartialOrd)]
pub struct Seconds(pub f64);

impl std::ops::Add<Seconds> for Seconds {
    type Output = Seconds;

    /// Add two second counts together.
    fn add(self, rhs: Seconds) -> Self::Output {
        let Seconds(lhs_secs) = self;
        let Seconds(rhs_secs) = rhs;
        Seconds(rhs_secs + lhs_secs)
    }
}

impl std::ops::Sub<Seconds> for Seconds {
    type Output = Seconds;

    /// Subtract `rhs` from `self`.
    fn sub(self, rhs: Seconds) -> Self::Output {
        let Seconds(minuend) = self;
        let Seconds(subtrahend) = rhs;
        Seconds(minuend - subtrahend)
    }
}

/// A type with a distinguished "zero" value, used as the origin when computing schedules.
pub trait Zero: PartialEq + Sized {
    /// Return this type's zero value.
    fn zero() -> Self;

    /// Return `true` when `self` compares equal to [`Zero::zero`].
    fn is_zero(&self) -> bool {
        *self == Self::zero()
    }
}

impl Zero for Seconds {
    fn zero() -> Self {
        Seconds(0.0)
    }
}
54
55#[derive(Clone, Debug, PartialEq)]
56pub struct Schedule<TimeUnit> {
57    items: Vec<ComputedScheduleItem<TimeUnit>>,
58    /// The total duration of the block. This is the end time of the schedule when it starts at `TimeUnit::zero()`
59    duration: TimeUnit,
60}
61
62impl<TimeUnit> Schedule<TimeUnit> {
63    pub fn duration(&self) -> &TimeUnit {
64        &self.duration
65    }
66
67    pub fn items(&self) -> &[ComputedScheduleItem<TimeUnit>] {
68        self.items.as_ref()
69    }
70
71    pub fn into_items(self) -> Vec<ComputedScheduleItem<TimeUnit>> {
72        self.items
73    }
74}
75
76impl<TimeUnit: Clone + PartialOrd + std::ops::Add<TimeUnit, Output = TimeUnit> + Zero>
77    From<Vec<ComputedScheduleItem<TimeUnit>>> for Schedule<TimeUnit>
78{
79    fn from(items: Vec<ComputedScheduleItem<TimeUnit>>) -> Self {
80        let duration = items
81            .iter()
82            .map(|item| item.time_span.start_time.clone() + item.time_span.duration.clone())
83            .fold(TimeUnit::zero(), |acc, el| if el > acc { el } else { acc });
84        Self { items, duration }
85    }
86}
87
88impl<TimeUnit: Zero> Default for Schedule<TimeUnit> {
89    fn default() -> Self {
90        Self {
91            items: Default::default(),
92            duration: TimeUnit::zero(),
93        }
94    }
95}
96
97pub type ScheduleSeconds = Schedule<Seconds>;
98
99#[derive(Clone, Debug, PartialEq)]
100pub struct ComputedScheduleItem<TimeUnit> {
101    pub time_span: TimeSpan<TimeUnit>,
102    pub instruction_index: usize,
103}
104
/// Errors that can occur while computing a [`Schedule`] from a [`ScheduledBasicBlock`].
#[derive(Debug, thiserror::Error)]
pub enum ComputedScheduleError {
    /// The duration callback could not determine a duration for this instruction.
    #[error("unknown duration for instruction {}", instruction.to_quil_or_debug())]
    UnknownDuration { instruction: Instruction },

    /// The dependency graph is inconsistent with its basic block (for example, it refers to an
    /// instruction index the block does not contain). This indicates an internal error.
    #[error("internal error: invalid dependency graph")]
    InvalidDependencyGraph,
}

/// Result type returned by schedule computation.
pub type ComputedScheduleResult<T> = Result<T, ComputedScheduleError>;
115
/// A contiguous span of time, expressed in some unit `TimeUnit`.
#[derive(Clone, Debug, PartialEq)]
pub struct TimeSpan<TimeUnit> {
    /// The inclusive start time of the described item
    pub start_time: TimeUnit,

    /// The described item's continuous duration
    pub duration: TimeUnit,
}

impl<TimeUnit> TimeSpan<TimeUnit> {
    /// Borrow the start time of this span.
    pub fn start_time(&self) -> &TimeUnit {
        let Self { start_time, .. } = self;
        start_time
    }

    /// Borrow the length of this span.
    pub fn duration(&self) -> &TimeUnit {
        let Self { duration, .. } = self;
        duration
    }
}

impl<TimeUnit: Clone + std::ops::Add<TimeUnit, Output = TimeUnit>> TimeSpan<TimeUnit> {
    /// The time at which this span ends, i.e. `start_time + duration`.
    pub fn end(&self) -> TimeUnit {
        let Self {
            start_time,
            duration,
        } = self;
        start_time.clone() + duration.clone()
    }
}

impl<TimeUnit> TimeSpan<TimeUnit>
where
    TimeUnit: Clone
        + PartialOrd
        + std::ops::Add<TimeUnit, Output = TimeUnit>
        + std::ops::Sub<TimeUnit, Output = TimeUnit>,
{
    /// Return the smallest span that covers both `self` and `rhs`. The spans need not overlap:
    /// any gap between them is included.
    pub(crate) fn union(self, rhs: Self) -> Self {
        let self_end_time = self.end();
        let rhs_end_time = rhs.end();
        let end_time = if self_end_time < rhs_end_time {
            rhs_end_time
        } else {
            self_end_time
        };

        let start_time = if rhs.start_time < self.start_time {
            rhs.start_time
        } else {
            self.start_time
        };

        Self {
            duration: end_time - start_time.clone(),
            start_time,
        }
    }
}
170
171impl<'p> ScheduledBasicBlock<'p> {
172    /// Return the duration of a scheduled Quil instruction:
173    ///
174    /// * For PULSE and CAPTURE, this is the duration of the waveform at the frame's sample rate
175    /// * For DELAY and RAW-CAPTURE, it's the named duration
176    /// * For supporting instructions like SET-*, SHIFT-*, and FENCE, it's 0
177    ///
178    /// Return `None` for other instructions.
179    pub(crate) fn get_instruction_duration_seconds(
180        program: &Program,
181        instruction: &Instruction,
182    ) -> Option<Seconds> {
183        match instruction {
184            Instruction::Capture(Capture { waveform, .. })
185            | Instruction::Pulse(Pulse { waveform, .. }) => {
186                Self::get_waveform_duration_seconds(program, instruction, waveform)
187            }
188            Instruction::Delay(Delay { duration, .. })
189            | Instruction::RawCapture(RawCapture { duration, .. }) => {
190                duration.to_real().ok().map(Seconds)
191            }
192            Instruction::Fence(_)
193            | Instruction::SetFrequency(_)
194            | Instruction::SetPhase(_)
195            | Instruction::SetScale(_)
196            | Instruction::ShiftFrequency(_)
197            | Instruction::ShiftPhase(_)
198            | Instruction::SwapPhases(_) => Some(Seconds(0.0)),
199            _ => None,
200        }
201    }
202
203    /// Return the duration of a Quil waveform:
204    ///
205    /// If the waveform is defined in the program with `DEFWAVEFORM`, the duration is the sample count
206    /// divided by the sample rate.
207    ///
208    /// Otherwise, it's the `duration` parameter of the waveform invocation. This relies on the assumption that
209    /// all template waveforms in use have such a parameter in units of seconds.
210    fn get_waveform_duration_seconds(
211        program: &Program,
212        instruction: &Instruction,
213        WaveformInvocation { name, parameters }: &WaveformInvocation,
214    ) -> Option<Seconds> {
215        if let Some(definition) = program.waveforms.get(name) {
216            let sample_count = definition.matrix.len();
217            let common_sample_rate =
218                program
219                    .get_frames_for_instruction(instruction)
220                    .and_then(|frames| {
221                        frames
222                            .used
223                            .into_iter()
224                            .filter_map(|frame| {
225                                program
226                                    .frames
227                                    .get(frame)
228                                    .and_then(|frame_definition| {
229                                        frame_definition.get("SAMPLE-RATE")
230                                    })
231                                    .and_then(|sample_rate_expression| match sample_rate_expression
232                                    {
233                                        AttributeValue::String(_) => None,
234                                        AttributeValue::Expression(expression) => Some(expression),
235                                    })
236                                    .and_then(|expression| expression.to_real().ok())
237                            })
238                            .all_equal_value()
239                            .ok()
240                    });
241
242            common_sample_rate
243                .map(|sample_rate| sample_count as f64 / sample_rate)
244                .map(Seconds)
245        } else {
246            // Per the Quil spec, all waveform templates have a "duration"
247            // parameter, and "erf_square" also has "pad_left" and "pad_right".
248            // We explicitly choose to be more flexible here, and allow any
249            // built-in waveform templates to have "pad_*" parameters, as well
250            // as allow "erf_square" to omit them.
251            let parameter = |parameter_name| {
252                parameters
253                    .get(parameter_name)
254                    .and_then(|v| v.to_real().ok())
255                    .map(Seconds)
256            };
257            Some(
258                parameter("duration")?
259                    + parameter("pad_left").unwrap_or(Seconds::zero())
260                    + parameter("pad_right").unwrap_or(Seconds::zero()),
261            )
262        }
263    }
264
265    /// Compute the flattened schedule for this [`ScheduledBasicBlock`] in terms of seconds,
266    /// using a default built-in calculation for the duration of scheduled instructions.
267    pub fn as_schedule_seconds(
268        &self,
269        program: &Program,
270    ) -> ComputedScheduleResult<ScheduleSeconds> {
271        self.as_schedule(program, Self::get_instruction_duration_seconds)
272    }
273
274    /// Compute the flattened schedule for this [`ScheduledBasicBlock`] using a user-provided
275    /// closure for computation of instruction duration.
276    ///
277    /// Return an error if the schedule cannot be computed from the information provided.
278    pub fn as_schedule<
279        F,
280        TimeUnit: Clone + PartialOrd + std::ops::Add<TimeUnit, Output = TimeUnit> + Zero,
281    >(
282        &self,
283        program: &'p Program,
284        get_duration: F,
285    ) -> ComputedScheduleResult<Schedule<TimeUnit>>
286    where
287        F: Fn(&'p Program, &'p Instruction) -> Option<TimeUnit>,
288    {
289        let mut schedule = Schedule::default();
290        let mut end_time_by_instruction_index = HashMap::<usize, TimeUnit>::new();
291
292        let graph_filtered = EdgeFiltered::from_fn(&self.graph, |(_, _, dependencies)| {
293            dependencies.contains(&ExecutionDependency::Scheduled)
294        });
295        let mut topo = Topo::new(&graph_filtered);
296
297        while let Some(instruction_node) = topo.next(&graph_filtered) {
298            if let ScheduledGraphNode::InstructionIndex(index) = instruction_node {
299                let instruction = *self
300                    .basic_block()
301                    .instructions()
302                    .get(index)
303                    .ok_or_else(|| ComputedScheduleError::InvalidDependencyGraph)?;
304                let duration = get_duration(program, instruction).ok_or(
305                    ComputedScheduleError::UnknownDuration {
306                        instruction: instruction.clone(),
307                    },
308                )?;
309
310                let latest_previous_instruction_scheduler_end_time = self
311                    .graph
312                    .edges_directed(instruction_node, Direction::Incoming)
313                    .filter_map(|(source, _, dependencies)| {
314                        if dependencies.contains(&ExecutionDependency::Scheduled) {
315                            match source {
316                                ScheduledGraphNode::BlockStart => Ok(Some(TimeUnit::zero())),
317                                ScheduledGraphNode::InstructionIndex(previous_index) => {
318                                    end_time_by_instruction_index
319                                        .get(&previous_index)
320                                        .cloned()
321                                        .ok_or(ComputedScheduleError::InvalidDependencyGraph)
322                                        .map(Some)
323                                }
324                                ScheduledGraphNode::BlockEnd => unreachable!(),
325                            }
326                        } else {
327                            Ok(None)
328                        }
329                        .transpose()
330                    })
331                    .collect::<Result<Vec<TimeUnit>, _>>()?
332                    .into_iter()
333                    // this implementation allows us to require PartialOrd instead of Ord (required for `.max()`),
334                    // which is convenient for f64
335                    .fold(TimeUnit::zero(), |acc, el| if el > acc { el } else { acc });
336
337                let start_time = latest_previous_instruction_scheduler_end_time;
338                let end_time = start_time.clone() + duration.clone();
339                if schedule.duration < end_time {
340                    schedule.duration = end_time.clone();
341                }
342
343                end_time_by_instruction_index.insert(index, end_time);
344                schedule.items.push(ComputedScheduleItem {
345                    time_span: TimeSpan {
346                        start_time,
347                        duration,
348                    },
349                    instruction_index: index,
350                });
351            }
352        }
353
354        Ok(schedule)
355    }
356}
357
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::{ComputedScheduleItem, Schedule, Seconds};
    use crate::{instruction::InstructionHandler, program::scheduling::TimeSpan, Program};

    #[rstest::rstest]
    #[case("CAPTURE 0 \"a\" flat(duration: 1.0) ro", Some(1.0))]
    #[case("DELAY 0 \"a\" 1.0", Some(1.0))]
    #[case("FENCE", Some(0.0))]
    #[case("PULSE 0 \"a\" flat(duration: 1.0)", Some(1.0))]
    #[case("RAW-CAPTURE 0 \"a\" 1.0 ro", Some(1.0))]
    #[case("RESET", None)]
    #[case("SET-FREQUENCY 0 \"a\" 1.0", Some(0.0))]
    #[case("SET-PHASE 0 \"a\" 1.0", Some(0.0))]
    #[case("SET-SCALE 0 \"a\" 1.0", Some(0.0))]
    #[case("SHIFT-FREQUENCY 0 \"a\" 1.0", Some(0.0))]
    #[case("SHIFT-PHASE 0 \"a\" 1.0", Some(0.0))]
    #[case("SWAP-PHASES 0 \"a\" 0 \"b\"", Some(0.0))]
    fn instruction_duration_seconds(
        #[case] input_program: &str,
        #[case] expected_duration: Option<f64>,
    ) {
        let empty_program = Program::new();
        let program = Program::from_str(input_program)
            .map_err(|e| e.to_string())
            .unwrap();
        let instruction = program.into_instructions().remove(0);
        let duration =
            crate::program::scheduling::ScheduledBasicBlock::get_instruction_duration_seconds(
                &empty_program,
                &instruction,
            );
        assert_eq!(expected_duration.map(Seconds), duration);
    }

    #[rstest::rstest]
    #[case(
        r#"FENCE
FENCE
FENCE
"#,
        Ok(vec![0.0, 0.0, 0.0])
    )]
    #[case(
        r#"DEFFRAME 0 "a":
    SAMPLE-RATE: 1e9
PULSE 0 "a" flat(duration: 1.0)
PULSE 0 "a" flat(duration: 1.0)
PULSE 0 "a" flat(duration: 1.0)
"#,
        Ok(vec![0.0, 1.0, 2.0])
    )]
    #[case(
        r#"DEFFRAME 0 "a":
    SAMPLE-RATE: 1e9
PULSE 0 "a" erf_square(duration: 1.0, pad_left: 0.2, pad_right: 0.3)
PULSE 0 "a" erf_square(duration: 0.1, pad_left: 0.7, pad_right: 0.7)
PULSE 0 "a" erf_square(duration: 0.5, pad_left: 0.6, pad_right: 0.4)
FENCE
"#,
        Ok(vec![0.0, 1.5, 3.0, 4.5])
    )]
    #[case(
        r#"DEFFRAME 0 "a":
    SAMPLE-RATE: 1e9
DEFFRAME 0 "b":
    SAMPLE-RATE: 1e9
NONBLOCKING PULSE 0 "a" flat(duration: 1.0)
NONBLOCKING PULSE 0 "b" flat(duration: 10.0)
FENCE
PULSE 0 "a" flat(duration: 1.0)
FENCE
PULSE 0 "a" flat(duration: 1.0)
"#,
        Ok(vec![0.0, 0.0, 10.0, 10.0, 11.0, 11.0])
    )]
    #[case(
        r#"DEFFRAME 0 "a":
    SAMPLE-RATE: 1e9
DEFFRAME 0 "b":
    SAMPLE-RATE: 1e9
DELAY 0 "a" 1.0
SET-PHASE 0 "a" 1.0
SHIFT-PHASE 0 "a" 1.0
SWAP-PHASES 0 "a" 0 "b"
SET-FREQUENCY 0 "a" 1.0
SHIFT-FREQUENCY 0 "a" 1.0
SET-SCALE 0 "a" 1.0
FENCE
PULSE 0 "a" flat(duration: 1.0)
"#,
        Ok(vec![0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
    )]
    #[case("RESET", Err(()))]
    fn schedule_seconds(#[case] input_program: &str, #[case] expected_times: Result<Vec<f64>, ()>) {
        let program: Program = input_program.parse().unwrap();
        let block: crate::program::analysis::BasicBlock = (&program).try_into().unwrap();
        let mut handler = InstructionHandler::default();
        let scheduled_block =
            crate::program::scheduling::ScheduledBasicBlock::build(block, &program, &mut handler)
                .unwrap();
        match (
            scheduled_block.as_schedule_seconds(&program),
            expected_times,
        ) {
            (Ok(schedule), Ok(expected_times)) => {
                let times = schedule
                    .items()
                    .iter()
                    .map(|item| item.time_span.start_time.0)
                    .collect::<Vec<_>>();
                assert_eq!(expected_times, times);
            }
            (Err(_), Err(_)) => {}
            (Ok(schedule), Err(_)) => {
                let times = schedule
                    .items()
                    .iter()
                    .map(|item| item.time_span.start_time.0)
                    .collect::<Vec<_>>();
                panic!("expected error, got {:?}", times);
            }
            (Err(error), Ok(_)) => {
                panic!("expected success, got error: {error}")
            }
        }
    }

    #[rstest::rstest]
    #[case::identical((0, 10), (0, 10), (0, 10))]
    #[case::adjacent((0, 1), (1, 1), (0, 2))]
    #[case::disjoint((0, 10), (20, 10), (0, 30))]
    #[case::disjoint_reverse((20, 10), (0, 10), (0, 30))]
    #[case::overlapping((0, 10), (5, 10), (0, 15))]
    #[case::contained((0, 10), (2, 3), (0, 10))]
    #[case::containing((2, 3), (0, 10), (0, 10))]
    fn time_span_union(
        #[case] a: (usize, usize),
        #[case] b: (usize, usize),
        #[case] expected: (usize, usize),
    ) {
        let a = TimeSpan {
            start_time: a.0,
            duration: a.1,
        };
        let b = TimeSpan {
            start_time: b.0,
            duration: b.1,
        };
        let expected = TimeSpan {
            start_time: expected.0,
            duration: expected.1,
        };
        assert_eq!(expected, a.union(b));
    }

    /// `Schedule::from` should report the latest item end time as its duration, and zero when empty.
    #[test]
    fn schedule_from_items_uses_latest_end_time() {
        let item = |start_time: f64, duration: f64, instruction_index: usize| ComputedScheduleItem {
            time_span: TimeSpan {
                start_time: Seconds(start_time),
                duration: Seconds(duration),
            },
            instruction_index,
        };

        let schedule = Schedule::from(vec![
            item(0.0, 5.0, 0),
            item(2.0, 5.0, 1),
            item(1.0, 1.0, 2),
        ]);
        assert_eq!(&Seconds(7.0), schedule.duration());
        assert_eq!(3, schedule.items().len());

        let empty = Schedule::<Seconds>::from(Vec::new());
        assert_eq!(&Seconds(0.0), empty.duration());
        assert!(empty.items().is_empty());
    }
}