spitfire_input/
lib.rs

1#[cfg(not(target_arch = "wasm32"))]
2use glutin::event::{ElementState, MouseButton, MouseScrollDelta, VirtualKeyCode, WindowEvent};
3use std::{
4    borrow::Cow,
5    cmp::Ordering,
6    collections::HashMap,
7    sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard},
8};
9use typid::ID;
10#[cfg(target_arch = "wasm32")]
11use winit::event::{ElementState, MouseButton, MouseScrollDelta, VirtualKeyCode, WindowEvent};
12
/// Policy deciding whether a mapping stops an event from reaching the
/// mappings stacked below it in an [`InputContext`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum InputConsume {
    /// Never stop propagation; mappings below always get to see the event.
    #[default]
    None,
    /// Stop propagation only when the event updated at least one binding of
    /// this mapping.
    Hit,
    /// Always stop propagation, whether or not any binding matched.
    All,
}
20
/// Physical input source that can be bound to an [`InputAction`] state.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VirtualAction {
    /// Keyboard key identified by its virtual key code.
    KeyButton(VirtualKeyCode),
    /// Mouse button.
    MouseButton(MouseButton),
    /// Device axis identified by its index; `InputContext::on_event` treats
    /// it as held while the absolute axis value is greater than `0.5`.
    Axis(u32),
}
27
/// Physical input source that can be bound to an [`InputAxis`] value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VirtualAxis {
    /// Keyboard key; the axis is set to `1.0` on press and `0.0` on release.
    KeyButton(VirtualKeyCode),
    /// Horizontal cursor position as reported by cursor-moved events.
    MousePositionX,
    /// Vertical cursor position as reported by cursor-moved events.
    MousePositionY,
    /// Horizontal mouse wheel delta; reset to `0.0` by `InputContext::maintain`.
    MouseWheelX,
    /// Vertical mouse wheel delta; reset to `0.0` by `InputContext::maintain`.
    MouseWheelY,
    /// Mouse button; the axis is set to `1.0` on press and `0.0` on release.
    MouseButton(MouseButton),
    /// Device axis identified by its index; receives the reported axis value.
    Axis(u32),
}
38
/// Four-phase state of a button-like input.
///
/// `Pressed` and `Released` are the transitional phases; `Hold` and `Idle`
/// are the steady phases they settle into.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum InputAction {
    /// The input is up and has been up.
    #[default]
    Idle,
    /// The input just went down.
    Pressed,
    /// The input is down and has been down.
    Hold,
    /// The input just went up.
    Released,
}

impl InputAction {
    /// Returns the next phase given whether the input is currently held down.
    ///
    /// A down report moves an up phase to `Pressed` and a down phase to
    /// `Hold`; an up report moves a down phase to `Released` and an up phase
    /// to `Idle`.
    pub fn change(self, hold: bool) -> Self {
        if hold {
            match self {
                Self::Idle | Self::Released => Self::Pressed,
                Self::Pressed | Self::Hold => Self::Hold,
            }
        } else {
            match self {
                Self::Pressed | Self::Hold => Self::Released,
                Self::Released | Self::Idle => Self::Idle,
            }
        }
    }

    /// Settles a transitional phase into its steady counterpart
    /// (`Pressed` -> `Hold`, `Released` -> `Idle`); steady phases are kept.
    pub fn update(self) -> Self {
        match self {
            Self::Pressed | Self::Hold => Self::Hold,
            Self::Released | Self::Idle => Self::Idle,
        }
    }

    /// `true` only for `Idle`.
    pub fn is_idle(self) -> bool {
        self == Self::Idle
    }

    /// `true` only for `Pressed`.
    pub fn is_pressed(self) -> bool {
        self == Self::Pressed
    }

    /// `true` only for `Hold`.
    pub fn is_hold(self) -> bool {
        self == Self::Hold
    }

    /// `true` only for `Released`.
    pub fn is_released(self) -> bool {
        self == Self::Released
    }

    /// `true` while the input is up (`Idle` or `Released`).
    pub fn is_up(self) -> bool {
        !self.is_down()
    }

    /// `true` while the input is down (`Pressed` or `Hold`).
    pub fn is_down(self) -> bool {
        match self {
            Self::Pressed | Self::Hold => true,
            Self::Idle | Self::Released => false,
        }
    }

    /// `true` for the transitional phases (`Pressed` or `Released`).
    pub fn is_changing(self) -> bool {
        match self {
            Self::Pressed | Self::Released => true,
            Self::Idle | Self::Hold => false,
        }
    }

    /// `true` for the steady phases (`Idle` or `Hold`).
    pub fn is_continuing(self) -> bool {
        !self.is_changing()
    }

    /// Maps the state to a number: `truthy` while down, `falsy` while up.
    pub fn to_scalar(self, falsy: f32, truthy: f32) -> f32 {
        match self {
            Self::Pressed | Self::Hold => truthy,
            Self::Idle | Self::Released => falsy,
        }
    }
}
107
/// Scalar value of an analog-like input.
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub struct InputAxis(pub f32);

impl InputAxis {
    /// Tells whether the axis value has reached `value` (inclusive).
    pub fn threshold(self, value: f32) -> bool {
        let Self(current) = self;
        current >= value
    }
}
116
/// Shared handle to an input value stored behind `Arc<RwLock<T>>`.
///
/// Cloning the handle clones the `Arc`, so all clones observe and mutate the
/// same value. A poisoned lock is never propagated as a panic: accessors
/// report it as `None`, fall back to `T::default()`, or do nothing.
#[derive(Debug, Default, Clone)]
pub struct InputRef<T: Default + Clone>(Arc<RwLock<T>>);

impl<T: Default + Clone> InputRef<T> {
    /// Wraps `data` in a new shared handle.
    pub fn new(data: T) -> Self {
        Self(Arc::new(RwLock::new(data)))
    }

    /// Acquires shared access to the value; `None` if the lock is poisoned.
    ///
    /// The guard borrows from `self`, which the explicit `'_` makes visible.
    pub fn read(&self) -> Option<RwLockReadGuard<'_, T>> {
        self.0.read().ok()
    }

    /// Acquires exclusive access to the value; `None` if the lock is poisoned.
    ///
    /// The guard borrows from `self`, which the explicit `'_` makes visible.
    pub fn write(&self) -> Option<RwLockWriteGuard<'_, T>> {
        self.0.write().ok()
    }

    /// Returns a clone of the value, or `T::default()` if the lock is poisoned.
    pub fn get(&self) -> T {
        self.read().map(|value| T::clone(&value)).unwrap_or_default()
    }

    /// Overwrites the value; silently does nothing if the lock is poisoned.
    pub fn set(&self, value: T) {
        if let Some(mut data) = self.write() {
            *data = value;
        }
    }
}
143
/// Shared handle to an [`InputAction`] state.
pub type InputActionRef = InputRef<InputAction>;
/// Shared handle to an [`InputAxis`] value.
pub type InputAxisRef = InputRef<InputAxis>;
/// Shared handle to the [`InputCharacters`] text buffer.
pub type InputCharactersRef = InputRef<InputCharacters>;
/// Shared handle to an [`InputMapping`].
pub type InputMappingRef = InputRef<InputMapping>;
148
149#[derive(Debug, Default, Clone)]
150pub enum InputActionOrAxisRef {
151    #[default]
152    None,
153    Action(InputActionRef),
154    Axis(InputAxisRef),
155}
156
157impl InputActionOrAxisRef {
158    pub fn is_none(&self) -> bool {
159        matches!(self, Self::None)
160    }
161
162    pub fn is_some(&self) -> bool {
163        !self.is_none()
164    }
165
166    pub fn get_scalar(&self, falsy: f32, truthy: f32) -> f32 {
167        match self {
168            Self::None => falsy,
169            Self::Action(action) => action.get().to_scalar(falsy, truthy),
170            Self::Axis(axis) => axis.get().0,
171        }
172    }
173
174    pub fn threshold(&self, value: f32) -> bool {
175        match self {
176            Self::None => false,
177            Self::Action(action) => action.get().is_down(),
178            Self::Axis(axis) => axis.get().threshold(value),
179        }
180    }
181}
182
183impl From<InputActionRef> for InputActionOrAxisRef {
184    fn from(value: InputActionRef) -> Self {
185        Self::Action(value)
186    }
187}
188
189impl From<InputAxisRef> for InputActionOrAxisRef {
190    fn from(value: InputAxisRef) -> Self {
191        Self::Axis(value)
192    }
193}
194
/// Lazily evaluated value produced by a boxed closure each time it is read.
pub struct InputCombinator<T> {
    mapper: Box<dyn Fn() -> T>,
}

impl<T: Default> Default for InputCombinator<T> {
    /// Creates a combinator that always yields `T::default()`.
    fn default() -> Self {
        // A closure (rather than the bare `T::default` fn item) keeps this
        // impl free of any extra lifetime requirement on `T`.
        let fallback = || T::default();
        Self::new(fallback)
    }
}

impl<T> InputCombinator<T> {
    /// Stores `mapper`, which is invoked on every [`InputCombinator::get`].
    pub fn new(mapper: impl Fn() -> T + 'static) -> Self {
        let mapper: Box<dyn Fn() -> T> = Box::new(mapper);
        Self { mapper }
    }

    /// Runs the stored closure and returns its result.
    pub fn get(&self) -> T {
        let Self { mapper } = self;
        mapper()
    }
}
216
217#[derive(Default)]
218pub struct CardinalInputCombinator(InputCombinator<[f32; 2]>);
219
220impl CardinalInputCombinator {
221    pub fn new(
222        left: impl Into<InputActionOrAxisRef>,
223        right: impl Into<InputActionOrAxisRef>,
224        up: impl Into<InputActionOrAxisRef>,
225        down: impl Into<InputActionOrAxisRef>,
226    ) -> Self {
227        let left = left.into();
228        let right = right.into();
229        let up = up.into();
230        let down = down.into();
231        Self(InputCombinator::new(move || {
232            let left = left.get_scalar(0.0, -1.0);
233            let right = right.get_scalar(0.0, 1.0);
234            let up = up.get_scalar(0.0, -1.0);
235            let down = down.get_scalar(0.0, 1.0);
236            [left + right, up + down]
237        }))
238    }
239
240    pub fn get(&self) -> [f32; 2] {
241        self.0.get()
242    }
243}
244
245#[derive(Default)]
246pub struct DualInputCombinator(InputCombinator<f32>);
247
248impl DualInputCombinator {
249    pub fn new(
250        negative: impl Into<InputActionOrAxisRef>,
251        positive: impl Into<InputActionOrAxisRef>,
252    ) -> Self {
253        let negative = negative.into();
254        let positive = positive.into();
255        Self(InputCombinator::new(move || {
256            let negative = negative.get_scalar(0.0, -1.0);
257            let positive = positive.get_scalar(0.0, 1.0);
258            negative + positive
259        }))
260    }
261
262    pub fn get(&self) -> f32 {
263        self.0.get()
264    }
265}
266
267pub struct ArrayInputCombinator<const N: usize>(InputCombinator<[f32; N]>);
268
269impl<const N: usize> Default for ArrayInputCombinator<N> {
270    fn default() -> Self {
271        Self(InputCombinator::new(|| {
272            std::array::from_fn(|_| Default::default())
273        }))
274    }
275}
276
277impl<const N: usize> ArrayInputCombinator<N> {
278    pub fn new(inputs: [impl Into<InputActionOrAxisRef>; N]) -> Self {
279        let mut items = std::array::from_fn::<InputActionOrAxisRef, N, _>(|_| Default::default());
280        for (index, input) in inputs.into_iter().enumerate() {
281            items[index] = input.into();
282        }
283        Self(InputCombinator::new(move || {
284            std::array::from_fn(|index| items[index].get_scalar(0.0, 1.0))
285        }))
286    }
287
288    pub fn get(&self) -> [f32; N] {
289        self.0.get()
290    }
291}
292
/// Buffer of text characters collected from input events.
#[derive(Debug, Default, Clone)]
pub struct InputCharacters {
    characters: String,
}

impl InputCharacters {
    /// Borrows the buffered text.
    pub fn read(&self) -> &str {
        self.characters.as_str()
    }

    /// Borrows the underlying buffer for in-place editing.
    pub fn write(&mut self) -> &mut String {
        let Self { characters } = self;
        characters
    }

    /// Returns the buffered text and leaves the buffer empty.
    pub fn take(&mut self) -> String {
        let Self { characters } = self;
        std::mem::take(characters)
    }
}
311
312#[derive(Debug, Default, Clone)]
313pub struct InputMapping {
314    pub actions: HashMap<VirtualAction, InputActionRef>,
315    pub axes: HashMap<VirtualAxis, InputAxisRef>,
316    pub consume: InputConsume,
317    pub layer: isize,
318    pub name: Cow<'static, str>,
319}
320
321impl InputMapping {
322    pub fn action(mut self, id: VirtualAction, action: InputActionRef) -> Self {
323        self.actions.insert(id, action);
324        self
325    }
326
327    pub fn axis(mut self, id: VirtualAxis, axis: InputAxisRef) -> Self {
328        self.axes.insert(id, axis);
329        self
330    }
331
332    pub fn consume(mut self, consume: InputConsume) -> Self {
333        self.consume = consume;
334        self
335    }
336
337    pub fn layer(mut self, value: isize) -> Self {
338        self.layer = value;
339        self
340    }
341
342    pub fn name(mut self, value: impl Into<Cow<'static, str>>) -> Self {
343        self.name = value.into();
344        self
345    }
346}
347
348impl From<InputMapping> for InputMappingRef {
349    fn from(value: InputMapping) -> Self {
350        Self::new(value)
351    }
352}
353
354#[derive(Debug, Clone)]
355pub struct InputContext {
356    pub mouse_wheel_line_scale: f32,
357    /// [(id, mapping)]
358    mappings_stack: Vec<(ID<InputMapping>, InputMappingRef)>,
359    characters: InputCharactersRef,
360}
361
362impl Default for InputContext {
363    fn default() -> Self {
364        Self {
365            mouse_wheel_line_scale: Self::default_mouse_wheel_line_scale(),
366            mappings_stack: Default::default(),
367            characters: Default::default(),
368        }
369    }
370}
371
372impl InputContext {
373    fn default_mouse_wheel_line_scale() -> f32 {
374        10.0
375    }
376
377    pub fn push_mapping(&mut self, mapping: impl Into<InputMappingRef>) -> ID<InputMapping> {
378        let mapping = mapping.into();
379        let id = ID::default();
380        let layer = mapping.read().unwrap().layer;
381        let index = self
382            .mappings_stack
383            .binary_search_by(|(_, mapping)| {
384                mapping
385                    .read()
386                    .unwrap()
387                    .layer
388                    .cmp(&layer)
389                    .then(Ordering::Less)
390            })
391            .unwrap_or_else(|index| index);
392        self.mappings_stack.insert(index, (id, mapping));
393        id
394    }
395
396    pub fn pop_mapping(&mut self) -> Option<InputMappingRef> {
397        self.mappings_stack.pop().map(|(_, mapping)| mapping)
398    }
399
400    pub fn top_mapping(&self) -> Option<&InputMappingRef> {
401        self.mappings_stack.last().map(|(_, mapping)| mapping)
402    }
403
404    pub fn remove_mapping(&mut self, id: ID<InputMapping>) -> Option<InputMappingRef> {
405        self.mappings_stack
406            .iter()
407            .position(|(mid, _)| mid == &id)
408            .map(|index| self.mappings_stack.remove(index).1)
409    }
410
411    pub fn mapping(&self, id: ID<InputMapping>) -> Option<RwLockReadGuard<InputMapping>> {
412        self.mappings_stack
413            .iter()
414            .find(|(mid, _)| mid == &id)
415            .and_then(|(_, mapping)| mapping.read())
416    }
417
418    pub fn stack(&self) -> impl Iterator<Item = &InputMappingRef> {
419        self.mappings_stack.iter().map(|(_, mapping)| mapping)
420    }
421
422    pub fn characters(&self) -> InputCharactersRef {
423        self.characters.clone()
424    }
425
426    pub fn maintain(&mut self) {
427        for (_, mapping) in &mut self.mappings_stack {
428            if let Some(mut mapping) = mapping.write() {
429                for action in mapping.actions.values_mut() {
430                    if let Some(mut action) = action.write() {
431                        *action = action.update();
432                    }
433                }
434                for (id, axis) in &mut mapping.axes {
435                    if let VirtualAxis::MouseWheelX | VirtualAxis::MouseWheelY = id {
436                        if let Some(mut axis) = axis.write() {
437                            axis.0 = 0.0;
438                        }
439                    }
440                }
441            }
442        }
443    }
444
445    pub fn on_event(&mut self, event: &WindowEvent) {
446        match event {
447            WindowEvent::ReceivedCharacter(character) => {
448                if let Some(mut characters) = self.characters.write() {
449                    characters.characters.push(*character);
450                }
451            }
452            WindowEvent::KeyboardInput { input, .. } => {
453                if let Some(key) = input.virtual_keycode {
454                    for (_, mapping) in self.mappings_stack.iter().rev() {
455                        if let Some(mapping) = mapping.read() {
456                            let mut consume = mapping.consume == InputConsume::All;
457                            for (id, data) in &mapping.actions {
458                                if let VirtualAction::KeyButton(button) = id {
459                                    if *button == key {
460                                        if let Some(mut data) = data.write() {
461                                            *data =
462                                                data.change(input.state == ElementState::Pressed);
463                                            if mapping.consume == InputConsume::Hit {
464                                                consume = true;
465                                            }
466                                        }
467                                    }
468                                }
469                            }
470                            for (id, data) in &mapping.axes {
471                                if let VirtualAxis::KeyButton(button) = id {
472                                    if *button == key {
473                                        if let Some(mut data) = data.write() {
474                                            data.0 = if input.state == ElementState::Pressed {
475                                                1.0
476                                            } else {
477                                                0.0
478                                            };
479                                            if mapping.consume == InputConsume::Hit {
480                                                consume = true;
481                                            }
482                                        }
483                                    }
484                                }
485                            }
486                            if consume {
487                                break;
488                            }
489                        }
490                    }
491                }
492            }
493            WindowEvent::CursorMoved { position, .. } => {
494                for (_, mapping) in self.mappings_stack.iter().rev() {
495                    if let Some(mapping) = mapping.read() {
496                        let mut consume = mapping.consume == InputConsume::All;
497                        for (id, data) in &mapping.axes {
498                            match id {
499                                VirtualAxis::MousePositionX => {
500                                    if let Some(mut data) = data.write() {
501                                        data.0 = position.x as _;
502                                        if mapping.consume == InputConsume::Hit {
503                                            consume = true;
504                                        }
505                                    }
506                                }
507                                VirtualAxis::MousePositionY => {
508                                    if let Some(mut data) = data.write() {
509                                        data.0 = position.y as _;
510                                        if mapping.consume == InputConsume::Hit {
511                                            consume = true;
512                                        }
513                                    }
514                                }
515                                _ => {}
516                            }
517                        }
518                        if consume {
519                            break;
520                        }
521                    }
522                }
523            }
524            WindowEvent::MouseWheel { delta, .. } => {
525                for (_, mapping) in self.mappings_stack.iter().rev() {
526                    if let Some(mapping) = mapping.read() {
527                        let mut consume = mapping.consume == InputConsume::All;
528                        for (id, data) in &mapping.axes {
529                            match id {
530                                VirtualAxis::MouseWheelX => {
531                                    if let Some(mut data) = data.write() {
532                                        data.0 = match delta {
533                                            MouseScrollDelta::LineDelta(x, _) => *x,
534                                            MouseScrollDelta::PixelDelta(pos) => pos.x as _,
535                                        };
536                                        if mapping.consume == InputConsume::Hit {
537                                            consume = true;
538                                        }
539                                    }
540                                }
541                                VirtualAxis::MouseWheelY => {
542                                    if let Some(mut data) = data.write() {
543                                        data.0 = match delta {
544                                            MouseScrollDelta::LineDelta(_, y) => *y,
545                                            MouseScrollDelta::PixelDelta(pos) => pos.y as _,
546                                        };
547                                        if mapping.consume == InputConsume::Hit {
548                                            consume = true;
549                                        }
550                                    }
551                                }
552                                _ => {}
553                            }
554                        }
555                        if consume {
556                            break;
557                        }
558                    }
559                }
560            }
561            WindowEvent::MouseInput { state, button, .. } => {
562                for (_, mapping) in self.mappings_stack.iter().rev() {
563                    if let Some(mapping) = mapping.read() {
564                        let mut consume = mapping.consume == InputConsume::All;
565                        for (id, data) in &mapping.actions {
566                            if let VirtualAction::MouseButton(btn) = id {
567                                if button == btn {
568                                    if let Some(mut data) = data.write() {
569                                        *data = data.change(*state == ElementState::Pressed);
570                                        if mapping.consume == InputConsume::Hit {
571                                            consume = true;
572                                        }
573                                    }
574                                }
575                            }
576                        }
577                        for (id, data) in &mapping.axes {
578                            if let VirtualAxis::MouseButton(btn) = id {
579                                if button == btn {
580                                    if let Some(mut data) = data.write() {
581                                        data.0 = if *state == ElementState::Pressed {
582                                            1.0
583                                        } else {
584                                            0.0
585                                        };
586                                        if mapping.consume == InputConsume::Hit {
587                                            consume = true;
588                                        }
589                                    }
590                                }
591                            }
592                        }
593                        if consume {
594                            break;
595                        }
596                    }
597                }
598            }
599            WindowEvent::AxisMotion { axis, value, .. } => {
600                for (_, mapping) in self.mappings_stack.iter().rev() {
601                    if let Some(mapping) = mapping.read() {
602                        let mut consume = mapping.consume == InputConsume::All;
603                        for (id, data) in &mapping.actions {
604                            if let VirtualAction::Axis(index) = id {
605                                if axis == index {
606                                    if let Some(mut data) = data.write() {
607                                        *data = data.change(value.abs() > 0.5);
608                                        if mapping.consume == InputConsume::Hit {
609                                            consume = true;
610                                        }
611                                    }
612                                }
613                            }
614                        }
615                        for (id, data) in &mapping.axes {
616                            if let VirtualAxis::Axis(index) = id {
617                                if axis == index {
618                                    if let Some(mut data) = data.write() {
619                                        data.0 = *value as _;
620                                        if mapping.consume == InputConsume::Hit {
621                                            consume = true;
622                                        }
623                                    }
624                                }
625                            }
626                        }
627                        if consume {
628                            break;
629                        }
630                    }
631                }
632            }
633            _ => {}
634        }
635    }
636}
637
#[cfg(test)]
mod tests {
    use crate::{InputContext, InputMapping};

    /// Mappings must end up ordered by layer, and by push order within a layer.
    #[test]
    fn test_stack() {
        let pushed = [
            ("a", 0),
            ("b", 0),
            ("c", 0),
            ("d", -1),
            ("e", 1),
            ("f", -1),
            ("g", 1),
            ("h", -2),
            ("i", -2),
            ("j", 2),
            ("k", 2),
        ];
        let mut context = InputContext::default();
        for (name, layer) in pushed {
            context.push_mapping(InputMapping::default().name(name).layer(layer));
        }

        let provided: Vec<(String, isize)> = context
            .stack()
            .map(|mapping| {
                let mapping = mapping.read().unwrap();
                (mapping.name.as_ref().to_owned(), mapping.layer)
            })
            .collect();

        let expected: Vec<(String, isize)> = [
            ("h", -2),
            ("i", -2),
            ("d", -1),
            ("f", -1),
            ("a", 0),
            ("b", 0),
            ("c", 0),
            ("e", 1),
            ("g", 1),
            ("j", 2),
            ("k", 2),
        ]
        .iter()
        .map(|&(name, layer)| (name.to_owned(), layer))
        .collect();

        assert_eq!(provided, expected);
    }
}