1use std::collections::HashMap;
5use std::ops::{Range, RangeFrom, RangeTo};
6
7use crate::{
8 chain::{Chain, Link},
9 collect::Chains,
10 node::State,
11 resource::{AccessFlags, Buffer, Image, Resource},
12 schedule::{Queue, QueueId, Schedule, SubmissionId},
13 Id,
14};
15
16#[derive(Clone, Debug, PartialEq, Eq, Hash)]
20struct Semaphore {
21 id: Id,
22 points: Range<SubmissionId>,
23}
24
25impl Semaphore {
26 fn new(id: Id, points: Range<SubmissionId>) -> Self {
27 Semaphore { id, points }
28 }
29}
30
/// Semaphore entry stored in the `signal` list of a submission's `SyncData`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Signal<S>(S);

impl<S> Signal<S> {
    /// Wraps `semaphore` into a signal entry.
    fn new(semaphore: S) -> Self {
        Self(semaphore)
    }

    /// Borrows the wrapped semaphore.
    pub fn semaphore(&self) -> &S {
        let Signal(inner) = self;
        inner
    }
}
49
50#[derive(Copy, Clone, Debug, PartialEq, Eq)]
53pub struct Wait<S>(S, rendy_core::hal::pso::PipelineStage);
54
55impl<S> Wait<S> {
56 fn new(semaphore: S, stages: rendy_core::hal::pso::PipelineStage) -> Self {
60 Wait(semaphore, stages)
61 }
62
63 pub fn semaphore(&self) -> &S {
65 &self.0
66 }
67
68 pub fn stage(&self) -> rendy_core::hal::pso::PipelineStage {
70 self.1
71 }
72}
73
74#[derive(Clone, Debug)]
76pub struct Barrier<R: Resource> {
77 pub families: Option<Range<rendy_core::hal::queue::QueueFamilyId>>,
79
80 pub states: Range<(R::Access, R::Layout, rendy_core::hal::pso::PipelineStage)>,
82}
83
84impl<R> Barrier<R>
85where
86 R: Resource,
87{
88 fn new(states: Range<State<R>>) -> Self {
89 Barrier {
90 families: None,
91 states: (
92 states.start.access,
93 states.start.layout,
94 states.start.stages,
95 )..(states.end.access, states.end.layout, states.end.stages),
96 }
97 }
98
99 fn transfer(
100 families: Range<rendy_core::hal::queue::QueueFamilyId>,
101 states: Range<(R::Access, R::Layout)>,
102 ) -> Self {
103 Barrier {
104 families: Some(families),
105 states: (
106 states.start.0,
107 states.start.1,
108 rendy_core::hal::pso::PipelineStage::TOP_OF_PIPE,
109 )
110 ..(
111 states.end.0,
112 states.end.1,
113 rendy_core::hal::pso::PipelineStage::BOTTOM_OF_PIPE,
114 ),
115 }
116 }
117
118 fn acquire(
119 families: Range<rendy_core::hal::queue::QueueFamilyId>,
120 left: RangeFrom<R::Layout>,
121 right: RangeTo<(R::Access, R::Layout)>,
122 ) -> Self {
123 Self::transfer(
124 families,
125 (R::Access::empty(), left.start)..(right.end.0, right.end.1),
126 )
127 }
128
129 fn release(
130 families: Range<rendy_core::hal::queue::QueueFamilyId>,
131 left: RangeFrom<(R::Access, R::Layout)>,
132 right: RangeTo<R::Layout>,
133 ) -> Self {
134 Self::transfer(
135 families,
136 (left.start.0, left.start.1)..(R::Access::empty(), right.end),
137 )
138 }
139}
140
141pub type Barriers<R> = HashMap<Id, Barrier<R>>;
143
144pub type BufferBarriers = Barriers<Buffer>;
146
147pub type ImageBarriers = Barriers<Image>;
149
150#[derive(Clone, Debug)]
152pub struct Guard {
153 pub buffers: BufferBarriers,
155
156 pub images: ImageBarriers,
158}
159
160impl Guard {
161 fn new() -> Self {
162 Guard {
163 buffers: HashMap::default(),
164 images: HashMap::default(),
165 }
166 }
167
168 fn pick<R: Resource>(&mut self) -> &mut Barriers<R> {
169 use std::any::Any;
170 let Guard {
171 ref mut buffers,
172 ref mut images,
173 } = *self;
174 Any::downcast_mut(buffers)
175 .or_else(move || Any::downcast_mut(images))
176 .expect("`R` should be `Buffer` or `Image`")
177 }
178}
179
180#[derive(Clone, Debug)]
182pub struct SyncData<S, W> {
183 pub wait: Vec<Wait<W>>,
185
186 pub acquire: Guard,
189
190 pub release: Guard,
193
194 pub signal: Vec<Signal<S>>,
196}
197
198impl<S, W> SyncData<S, W> {
199 fn new() -> Self {
200 SyncData {
201 wait: Vec::new(),
202 acquire: Guard::new(),
203 release: Guard::new(),
204 signal: Vec::new(),
205 }
206 }
207
208 fn convert_signal<F, T>(self, mut f: F) -> SyncData<T, W>
209 where
210 F: FnMut(S) -> T,
211 {
212 SyncData {
213 wait: self.wait,
214 acquire: Guard {
215 buffers: self.acquire.buffers,
216 images: self.acquire.images,
217 },
218 release: Guard {
219 buffers: self.release.buffers,
220 images: self.release.images,
221 },
222 signal: self
223 .signal
224 .into_iter()
225 .map(|Signal(semaphore)| Signal(f(semaphore)))
226 .collect(),
227 }
228 }
229
230 fn convert_wait<F, T>(self, mut f: F) -> SyncData<S, T>
231 where
232 F: FnMut(W) -> T,
233 {
234 SyncData {
235 wait: self
236 .wait
237 .into_iter()
238 .map(|Wait(semaphore, stage)| Wait(f(semaphore), stage))
239 .collect(),
240 acquire: Guard {
241 buffers: self.acquire.buffers,
242 images: self.acquire.images,
243 },
244 release: Guard {
245 buffers: self.release.buffers,
246 images: self.release.images,
247 },
248 signal: self.signal,
249 }
250 }
251}
252
253struct SyncTemp(HashMap<SubmissionId, SyncData<Semaphore, Semaphore>>);
254impl SyncTemp {
255 fn get_sync(&mut self, sid: SubmissionId) -> &mut SyncData<Semaphore, Semaphore> {
256 self.0.entry(sid).or_insert_with(|| SyncData::new())
257 }
258}
259
260pub fn sync<F, S, W>(chains: &Chains, mut new_semaphore: F) -> Schedule<SyncData<S, W>>
262where
263 F: FnMut() -> (S, W),
264{
265 let ref schedule = chains.schedule;
266 let ref buffers = chains.buffers;
267 let ref images = chains.images;
268
269 let mut sync = SyncTemp(HashMap::default());
270 for (&id, chain) in buffers {
271 sync_chain(id, chain, schedule, &mut sync);
272 }
273 for (&id, chain) in images {
274 sync_chain(id, chain, schedule, &mut sync);
275 }
276 if schedule.queue_count() > 1 {
277 optimize(schedule, &mut sync);
278 }
279
280 let mut result = Schedule::new();
281 let mut signals: HashMap<Semaphore, Option<S>> = HashMap::default();
282 let mut waits: HashMap<Semaphore, Option<W>> = HashMap::default();
283
284 for queue in schedule.iter().flat_map(|family| family.iter()) {
285 let mut new_queue = Queue::new(queue.id());
286 for submission in queue.iter() {
287 let sync = if let Some(sync) = sync.0.remove(&submission.id()) {
288 let sync = sync.convert_signal(|semaphore| match signals.get_mut(&semaphore) {
289 None => {
290 let (signal, wait) = new_semaphore();
291 let old = waits.insert(semaphore, Some(wait));
292 assert!(old.is_none());
293 signal
294 }
295 Some(signal) => signal.take().unwrap(),
296 });
297 let sync = sync.convert_wait(|semaphore| match waits.get_mut(&semaphore) {
298 None => {
299 let (signal, wait) = new_semaphore();
300 let old = signals.insert(semaphore, Some(signal));
301 assert!(old.is_none());
302 wait
303 }
304 Some(wait) => wait.take().unwrap(),
305 });
306 sync
307 } else {
308 SyncData::new()
309 };
310 new_queue.add_submission_checked(submission.set_sync(sync));
311 }
312 result.set_queue(new_queue);
313 }
314
315 debug_assert!(sync.0.is_empty());
316 debug_assert!(signals.values().all(|x| x.is_none()));
317 debug_assert!(waits.values().all(|x| x.is_none()));
318
319 result
320}
321
322fn latest<R, S>(link: &Link<R>, schedule: &Schedule<S>) -> SubmissionId
325where
326 R: Resource,
327{
328 let (_, sid) = link
329 .queues()
330 .map(|(qid, queue)| {
331 let sid = SubmissionId::new(qid, queue.last);
332 (schedule[sid].submit_order(), sid)
333 })
334 .max_by_key(|&(submit_order, sid)| (submit_order, sid.queue().index()))
335 .unwrap();
336 sid
337}
338
339fn earliest<R, S>(link: &Link<R>, schedule: &Schedule<S>) -> SubmissionId
340where
341 R: Resource,
342{
343 let (_, sid) = link
344 .queues()
345 .map(|(qid, queue)| {
346 let sid = SubmissionId::new(qid, queue.first);
347 (schedule[sid].submit_order(), sid)
348 })
349 .min_by_key(|&(submit_order, sid)| (submit_order, sid.queue().index()))
350 .unwrap();
351 sid
352}
353
354fn generate_semaphore_pair<R: Resource>(
355 sync: &mut SyncTemp,
356 id: Id,
357 link: &Link<R>,
358 range: Range<SubmissionId>,
359) {
360 if range.start.queue() != range.end.queue() {
361 let semaphore = Semaphore::new(id, range.clone());
362 sync.get_sync(range.start)
363 .signal
364 .push(Signal::new(semaphore.clone()));
365 sync.get_sync(range.end)
366 .wait
367 .push(Wait::new(semaphore, link.queue(range.end.queue()).stages));
368 }
369}
370
371fn sync_chain<R, S>(id: Id, chain: &Chain<R>, schedule: &Schedule<S>, sync: &mut SyncTemp)
372where
373 R: Resource,
374{
375 let uid = id.into();
376
377 let pairs = chain
378 .links()
379 .windows(2)
380 .map(|pair| (&pair[0], &pair[1]))
381 .chain(
382 chain
383 .links()
384 .first()
385 .and_then(|first| chain.links().last().map(move |last| (last, first))),
386 );
387
388 for (prev_link, link) in pairs {
389 log::trace!("Sync {:#?}:{:#?}", prev_link.access(), link.access());
390 if prev_link.family() == link.family() {
391 if prev_link.access().exclusive() && !link.access().exclusive() {
393 let signal_sid = latest(prev_link, schedule);
394
395 sync.get_sync(signal_sid)
397 .release
398 .pick::<R>()
399 .insert(id, Barrier::new(prev_link.state()..link.state()));
400
401 for (queue_id, queue) in link.queues() {
403 let head = SubmissionId::new(queue_id, queue.first);
404 generate_semaphore_pair(sync, uid, link, signal_sid..head);
405 }
406 } else {
407 let wait_sid = earliest(link, schedule);
408
409 for (queue_id, queue) in prev_link.queues() {
411 let tail = SubmissionId::new(queue_id, queue.last);
412 generate_semaphore_pair(sync, uid, link, tail..wait_sid);
413 }
414
415 sync.get_sync(wait_sid)
417 .acquire
418 .pick()
419 .insert(id, Barrier::new(prev_link.state()..link.state()));
420
421 if !link.access().exclusive() {
422 unimplemented!("This case is unimplemented");
423 }
424 }
425 } else {
426 let signal_sid = latest(prev_link, schedule);
427 let wait_sid = earliest(link, schedule);
428
429 if !prev_link.access().exclusive() {
430 unimplemented!("This case is unimplemented");
431 }
432
433 generate_semaphore_pair(sync, uid, link, signal_sid..wait_sid);
435
436 sync.get_sync(signal_sid).release.pick::<R>().insert(
438 id,
439 Barrier::release(
440 signal_sid.family()..wait_sid.family(),
441 (prev_link.access(), prev_link.layout())..,
442 ..link.layout(),
443 ),
444 );
445 sync.get_sync(wait_sid).acquire.pick::<R>().insert(
446 id,
447 Barrier::acquire(
448 signal_sid.family()..wait_sid.family(),
449 prev_link.layout()..,
450 ..(link.access(), link.layout()),
451 ),
452 );
453
454 if !link.access().exclusive() {
455 unimplemented!("This case is unimplemented");
456 }
457 }
458 }
459}
460
461fn optimize_submission(
462 sid: SubmissionId,
463 found: &mut HashMap<QueueId, usize>,
464 sync: &mut SyncTemp,
465) {
466 let mut to_remove = Vec::new();
467 if let Some(sync_data) = sync.0.get_mut(&sid) {
468 sync_data
469 .wait
470 .sort_unstable_by_key(|wait| (wait.stage(), wait.semaphore().points.end.index()));
471 sync_data.wait.retain(|wait| {
472 let start = wait.semaphore().points.start;
473 if let Some(synched_to) = found.get_mut(&start.queue()) {
474 if *synched_to >= start.index() {
475 to_remove.push(wait.semaphore().clone());
476 return false;
477 } else {
478 *synched_to = start.index();
479 return true;
480 }
481 }
482 found.insert(start.queue(), start.index());
483 true
484 });
485 } else {
486 return;
487 }
488
489 for semaphore in to_remove.drain(..) {
490 let ref mut signal = sync.0.get_mut(&semaphore.points.start).unwrap().signal;
492 let index = signal
493 .iter()
494 .position(|signal| signal.0 == semaphore)
495 .unwrap();
496 signal.swap_remove(index);
497 }
498}
499
500fn optimize<S>(schedule: &Schedule<S>, sync: &mut SyncTemp) {
501 for queue in schedule.iter().flat_map(|family| family.iter()) {
502 let mut found = HashMap::default();
503 for submission in queue.iter() {
504 optimize_submission(submission.id(), &mut found, sync);
505 }
506 }
507}