1#[cfg(not(target_arch = "wasm32"))]
2use glutin::event::{ElementState, MouseButton, MouseScrollDelta, VirtualKeyCode, WindowEvent};
3use std::{
4 borrow::Cow,
5 cmp::Ordering,
6 collections::HashMap,
7 sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard},
8};
9use typid::ID;
10#[cfg(target_arch = "wasm32")]
11use winit::event::{ElementState, MouseButton, MouseScrollDelta, VirtualKeyCode, WindowEvent};
12
/// Policy deciding whether a mapping stops an event from propagating to the
/// mappings below it on the stack.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum InputConsume {
    /// Never stop propagation; lower mappings always see the event.
    #[default]
    None,
    /// Stop propagation only if this mapping updated one of its bound actions or axes.
    Hit,
    /// Always stop propagation once this mapping has processed the event.
    All,
}
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
22pub enum VirtualAction {
23 KeyButton(VirtualKeyCode),
24 MouseButton(MouseButton),
25 Axis(u32),
26}
27
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
29pub enum VirtualAxis {
30 KeyButton(VirtualKeyCode),
31 MousePositionX,
32 MousePositionY,
33 MouseWheelX,
34 MouseWheelY,
35 MouseButton(MouseButton),
36 Axis(u32),
37}
38
/// Discrete state of a button-like input.
///
/// The state is advanced by [`InputAction::change`] when an input event arrives and by
/// [`InputAction::update`] once the edge states (`Pressed` / `Released`) should decay.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum InputAction {
    /// Not held.
    #[default]
    Idle,
    /// Just became held.
    Pressed,
    /// Held continuously.
    Hold,
    /// Just stopped being held.
    Released,
}

impl InputAction {
    /// Returns the state that follows `self` after observing whether the input is held.
    ///
    /// Going down from `Idle`/`Released` yields `Pressed` (then `Hold`); going up from
    /// `Pressed`/`Hold` yields `Released` (then `Idle`).
    pub fn change(self, hold: bool) -> Self {
        if hold {
            match self {
                Self::Idle | Self::Released => Self::Pressed,
                Self::Pressed | Self::Hold => Self::Hold,
            }
        } else {
            match self {
                Self::Pressed | Self::Hold => Self::Released,
                Self::Released | Self::Idle => Self::Idle,
            }
        }
    }

    /// Decays the one-shot edge states: `Pressed` becomes `Hold`, `Released` becomes `Idle`.
    pub fn update(self) -> Self {
        match self {
            Self::Idle | Self::Released => Self::Idle,
            Self::Pressed | Self::Hold => Self::Hold,
        }
    }

    /// `true` only for [`InputAction::Idle`].
    pub fn is_idle(self) -> bool {
        self == Self::Idle
    }

    /// `true` only for [`InputAction::Pressed`].
    pub fn is_pressed(self) -> bool {
        self == Self::Pressed
    }

    /// `true` only for [`InputAction::Hold`].
    pub fn is_hold(self) -> bool {
        self == Self::Hold
    }

    /// `true` only for [`InputAction::Released`].
    pub fn is_released(self) -> bool {
        self == Self::Released
    }

    /// `true` when the input is not held (`Idle` or `Released`).
    pub fn is_up(self) -> bool {
        !self.is_down()
    }

    /// `true` when the input is held (`Pressed` or `Hold`).
    pub fn is_down(self) -> bool {
        self.is_pressed() || self.is_hold()
    }

    /// `true` on an edge (`Pressed` or `Released`).
    pub fn is_changing(self) -> bool {
        self.is_pressed() || self.is_released()
    }

    /// `true` when not on an edge (`Idle` or `Hold`).
    pub fn is_continuing(self) -> bool {
        !self.is_changing()
    }

    /// Maps the state to a scalar: `truthy` while held, `falsy` otherwise.
    pub fn to_scalar(self, falsy: f32, truthy: f32) -> f32 {
        if self.is_up() {
            falsy
        } else {
            truthy
        }
    }
}
107
/// Continuous input value (e.g. a cursor coordinate or a raw device axis).
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct InputAxis(pub f32);

impl InputAxis {
    /// Returns `true` when the stored value is greater than or equal to `value`.
    pub fn threshold(self, value: f32) -> bool {
        let Self(current) = self;
        current >= value
    }
}
116
/// Shared, thread-safe handle to a piece of input state.
///
/// Cloning the handle clones the inner `Arc`, so every clone refers to the *same*
/// value. All accessors treat a poisoned lock as unavailable instead of panicking:
/// `read`/`write` return `None`, `get` falls back to `T::default()`, and `set` is a no-op.
#[derive(Debug, Default, Clone)]
pub struct InputRef<T: Default + Clone>(Arc<RwLock<T>>);

impl<T: Default + Clone> InputRef<T> {
    /// Wraps `data` in a new shared handle.
    pub fn new(data: T) -> Self {
        Self(Arc::new(RwLock::new(data)))
    }

    /// Acquires a shared read guard, blocking until it is available.
    ///
    /// Returns `None` if the lock is poisoned.
    pub fn read(&self) -> Option<RwLockReadGuard<'_, T>> {
        self.0.read().ok()
    }

    /// Acquires an exclusive write guard, blocking until it is available.
    ///
    /// Returns `None` if the lock is poisoned.
    pub fn write(&self) -> Option<RwLockWriteGuard<'_, T>> {
        self.0.write().ok()
    }

    /// Returns a clone of the current value, or `T::default()` if the lock is poisoned.
    pub fn get(&self) -> T {
        self.read()
            .map(|guard| T::clone(&guard))
            .unwrap_or_default()
    }

    /// Replaces the current value; silently does nothing if the lock is poisoned.
    pub fn set(&self, value: T) {
        if let Some(mut data) = self.write() {
            *data = value;
        }
    }
}
143
144pub type InputActionRef = InputRef<InputAction>;
145pub type InputAxisRef = InputRef<InputAxis>;
146pub type InputCharactersRef = InputRef<InputCharacters>;
147pub type InputMappingRef = InputRef<InputMapping>;
148
149#[derive(Debug, Default, Clone)]
150pub enum InputActionOrAxisRef {
151 #[default]
152 None,
153 Action(InputActionRef),
154 Axis(InputAxisRef),
155}
156
157impl InputActionOrAxisRef {
158 pub fn is_none(&self) -> bool {
159 matches!(self, Self::None)
160 }
161
162 pub fn is_some(&self) -> bool {
163 !self.is_none()
164 }
165
166 pub fn get_scalar(&self, falsy: f32, truthy: f32) -> f32 {
167 match self {
168 Self::None => falsy,
169 Self::Action(action) => action.get().to_scalar(falsy, truthy),
170 Self::Axis(axis) => axis.get().0,
171 }
172 }
173
174 pub fn threshold(&self, value: f32) -> bool {
175 match self {
176 Self::None => false,
177 Self::Action(action) => action.get().is_down(),
178 Self::Axis(axis) => axis.get().threshold(value),
179 }
180 }
181}
182
183impl From<InputActionRef> for InputActionOrAxisRef {
184 fn from(value: InputActionRef) -> Self {
185 Self::Action(value)
186 }
187}
188
189impl From<InputAxisRef> for InputActionOrAxisRef {
190 fn from(value: InputAxisRef) -> Self {
191 Self::Axis(value)
192 }
193}
194
/// Derived input value computed by a boxed closure that is re-run on every `get`.
pub struct InputCombinator<T> {
    mapper: Box<dyn Fn() -> T>,
}

impl<T: Default> Default for InputCombinator<T> {
    /// A combinator that always yields `T::default()`.
    fn default() -> Self {
        // Kept as a closure (not `T::default`): a capture-less closure is `'static`
        // regardless of `T`, whereas the fn item would require `T: 'static`.
        Self::new(|| T::default())
    }
}

impl<T> InputCombinator<T> {
    /// Creates a combinator from `mapper`, which is invoked on each [`InputCombinator::get`].
    pub fn new(mapper: impl Fn() -> T + 'static) -> Self {
        let mapper: Box<dyn Fn() -> T> = Box::new(mapper);
        Self { mapper }
    }

    /// Evaluates the mapper and returns its result.
    pub fn get(&self) -> T {
        let mapper = &self.mapper;
        mapper()
    }
}
216
217#[derive(Default)]
218pub struct CardinalInputCombinator(InputCombinator<[f32; 2]>);
219
220impl CardinalInputCombinator {
221 pub fn new(
222 left: impl Into<InputActionOrAxisRef>,
223 right: impl Into<InputActionOrAxisRef>,
224 up: impl Into<InputActionOrAxisRef>,
225 down: impl Into<InputActionOrAxisRef>,
226 ) -> Self {
227 let left = left.into();
228 let right = right.into();
229 let up = up.into();
230 let down = down.into();
231 Self(InputCombinator::new(move || {
232 let left = left.get_scalar(0.0, -1.0);
233 let right = right.get_scalar(0.0, 1.0);
234 let up = up.get_scalar(0.0, -1.0);
235 let down = down.get_scalar(0.0, 1.0);
236 [left + right, up + down]
237 }))
238 }
239
240 pub fn get(&self) -> [f32; 2] {
241 self.0.get()
242 }
243}
244
245#[derive(Default)]
246pub struct DualInputCombinator(InputCombinator<f32>);
247
248impl DualInputCombinator {
249 pub fn new(
250 negative: impl Into<InputActionOrAxisRef>,
251 positive: impl Into<InputActionOrAxisRef>,
252 ) -> Self {
253 let negative = negative.into();
254 let positive = positive.into();
255 Self(InputCombinator::new(move || {
256 let negative = negative.get_scalar(0.0, -1.0);
257 let positive = positive.get_scalar(0.0, 1.0);
258 negative + positive
259 }))
260 }
261
262 pub fn get(&self) -> f32 {
263 self.0.get()
264 }
265}
266
267pub struct ArrayInputCombinator<const N: usize>(InputCombinator<[f32; N]>);
268
269impl<const N: usize> Default for ArrayInputCombinator<N> {
270 fn default() -> Self {
271 Self(InputCombinator::new(|| {
272 std::array::from_fn(|_| Default::default())
273 }))
274 }
275}
276
277impl<const N: usize> ArrayInputCombinator<N> {
278 pub fn new(inputs: [impl Into<InputActionOrAxisRef>; N]) -> Self {
279 let mut items = std::array::from_fn::<InputActionOrAxisRef, N, _>(|_| Default::default());
280 for (index, input) in inputs.into_iter().enumerate() {
281 items[index] = input.into();
282 }
283 Self(InputCombinator::new(move || {
284 std::array::from_fn(|index| items[index].get_scalar(0.0, 1.0))
285 }))
286 }
287
288 pub fn get(&self) -> [f32; N] {
289 self.0.get()
290 }
291}
292
/// Buffer of text characters received from the window.
#[derive(Clone, Debug, Default)]
pub struct InputCharacters {
    characters: String,
}

impl InputCharacters {
    /// Borrows the buffered text.
    pub fn read(&self) -> &str {
        self.characters.as_str()
    }

    /// Mutably borrows the underlying buffer.
    pub fn write(&mut self) -> &mut String {
        let Self { characters } = self;
        characters
    }

    /// Moves the buffered text out, leaving the buffer empty.
    pub fn take(&mut self) -> String {
        std::mem::take(&mut self.characters)
    }
}
311
312#[derive(Debug, Default, Clone)]
313pub struct InputMapping {
314 pub actions: HashMap<VirtualAction, InputActionRef>,
315 pub axes: HashMap<VirtualAxis, InputAxisRef>,
316 pub consume: InputConsume,
317 pub layer: isize,
318 pub name: Cow<'static, str>,
319}
320
321impl InputMapping {
322 pub fn action(mut self, id: VirtualAction, action: InputActionRef) -> Self {
323 self.actions.insert(id, action);
324 self
325 }
326
327 pub fn axis(mut self, id: VirtualAxis, axis: InputAxisRef) -> Self {
328 self.axes.insert(id, axis);
329 self
330 }
331
332 pub fn consume(mut self, consume: InputConsume) -> Self {
333 self.consume = consume;
334 self
335 }
336
337 pub fn layer(mut self, value: isize) -> Self {
338 self.layer = value;
339 self
340 }
341
342 pub fn name(mut self, value: impl Into<Cow<'static, str>>) -> Self {
343 self.name = value.into();
344 self
345 }
346}
347
348impl From<InputMapping> for InputMappingRef {
349 fn from(value: InputMapping) -> Self {
350 Self::new(value)
351 }
352}
353
354#[derive(Debug, Clone)]
355pub struct InputContext {
356 pub mouse_wheel_line_scale: f32,
357 mappings_stack: Vec<(ID<InputMapping>, InputMappingRef)>,
359 characters: InputCharactersRef,
360}
361
362impl Default for InputContext {
363 fn default() -> Self {
364 Self {
365 mouse_wheel_line_scale: Self::default_mouse_wheel_line_scale(),
366 mappings_stack: Default::default(),
367 characters: Default::default(),
368 }
369 }
370}
371
372impl InputContext {
373 fn default_mouse_wheel_line_scale() -> f32 {
374 10.0
375 }
376
377 pub fn push_mapping(&mut self, mapping: impl Into<InputMappingRef>) -> ID<InputMapping> {
378 let mapping = mapping.into();
379 let id = ID::default();
380 let layer = mapping.read().unwrap().layer;
381 let index = self
382 .mappings_stack
383 .binary_search_by(|(_, mapping)| {
384 mapping
385 .read()
386 .unwrap()
387 .layer
388 .cmp(&layer)
389 .then(Ordering::Less)
390 })
391 .unwrap_or_else(|index| index);
392 self.mappings_stack.insert(index, (id, mapping));
393 id
394 }
395
396 pub fn pop_mapping(&mut self) -> Option<InputMappingRef> {
397 self.mappings_stack.pop().map(|(_, mapping)| mapping)
398 }
399
400 pub fn top_mapping(&self) -> Option<&InputMappingRef> {
401 self.mappings_stack.last().map(|(_, mapping)| mapping)
402 }
403
404 pub fn remove_mapping(&mut self, id: ID<InputMapping>) -> Option<InputMappingRef> {
405 self.mappings_stack
406 .iter()
407 .position(|(mid, _)| mid == &id)
408 .map(|index| self.mappings_stack.remove(index).1)
409 }
410
411 pub fn mapping(&self, id: ID<InputMapping>) -> Option<RwLockReadGuard<InputMapping>> {
412 self.mappings_stack
413 .iter()
414 .find(|(mid, _)| mid == &id)
415 .and_then(|(_, mapping)| mapping.read())
416 }
417
418 pub fn stack(&self) -> impl Iterator<Item = &InputMappingRef> {
419 self.mappings_stack.iter().map(|(_, mapping)| mapping)
420 }
421
422 pub fn characters(&self) -> InputCharactersRef {
423 self.characters.clone()
424 }
425
426 pub fn maintain(&mut self) {
427 for (_, mapping) in &mut self.mappings_stack {
428 if let Some(mut mapping) = mapping.write() {
429 for action in mapping.actions.values_mut() {
430 if let Some(mut action) = action.write() {
431 *action = action.update();
432 }
433 }
434 for (id, axis) in &mut mapping.axes {
435 if let VirtualAxis::MouseWheelX | VirtualAxis::MouseWheelY = id {
436 if let Some(mut axis) = axis.write() {
437 axis.0 = 0.0;
438 }
439 }
440 }
441 }
442 }
443 }
444
445 pub fn on_event(&mut self, event: &WindowEvent) {
446 match event {
447 WindowEvent::ReceivedCharacter(character) => {
448 if let Some(mut characters) = self.characters.write() {
449 characters.characters.push(*character);
450 }
451 }
452 WindowEvent::KeyboardInput { input, .. } => {
453 if let Some(key) = input.virtual_keycode {
454 for (_, mapping) in self.mappings_stack.iter().rev() {
455 if let Some(mapping) = mapping.read() {
456 let mut consume = mapping.consume == InputConsume::All;
457 for (id, data) in &mapping.actions {
458 if let VirtualAction::KeyButton(button) = id {
459 if *button == key {
460 if let Some(mut data) = data.write() {
461 *data =
462 data.change(input.state == ElementState::Pressed);
463 if mapping.consume == InputConsume::Hit {
464 consume = true;
465 }
466 }
467 }
468 }
469 }
470 for (id, data) in &mapping.axes {
471 if let VirtualAxis::KeyButton(button) = id {
472 if *button == key {
473 if let Some(mut data) = data.write() {
474 data.0 = if input.state == ElementState::Pressed {
475 1.0
476 } else {
477 0.0
478 };
479 if mapping.consume == InputConsume::Hit {
480 consume = true;
481 }
482 }
483 }
484 }
485 }
486 if consume {
487 break;
488 }
489 }
490 }
491 }
492 }
493 WindowEvent::CursorMoved { position, .. } => {
494 for (_, mapping) in self.mappings_stack.iter().rev() {
495 if let Some(mapping) = mapping.read() {
496 let mut consume = mapping.consume == InputConsume::All;
497 for (id, data) in &mapping.axes {
498 match id {
499 VirtualAxis::MousePositionX => {
500 if let Some(mut data) = data.write() {
501 data.0 = position.x as _;
502 if mapping.consume == InputConsume::Hit {
503 consume = true;
504 }
505 }
506 }
507 VirtualAxis::MousePositionY => {
508 if let Some(mut data) = data.write() {
509 data.0 = position.y as _;
510 if mapping.consume == InputConsume::Hit {
511 consume = true;
512 }
513 }
514 }
515 _ => {}
516 }
517 }
518 if consume {
519 break;
520 }
521 }
522 }
523 }
524 WindowEvent::MouseWheel { delta, .. } => {
525 for (_, mapping) in self.mappings_stack.iter().rev() {
526 if let Some(mapping) = mapping.read() {
527 let mut consume = mapping.consume == InputConsume::All;
528 for (id, data) in &mapping.axes {
529 match id {
530 VirtualAxis::MouseWheelX => {
531 if let Some(mut data) = data.write() {
532 data.0 = match delta {
533 MouseScrollDelta::LineDelta(x, _) => *x,
534 MouseScrollDelta::PixelDelta(pos) => pos.x as _,
535 };
536 if mapping.consume == InputConsume::Hit {
537 consume = true;
538 }
539 }
540 }
541 VirtualAxis::MouseWheelY => {
542 if let Some(mut data) = data.write() {
543 data.0 = match delta {
544 MouseScrollDelta::LineDelta(_, y) => *y,
545 MouseScrollDelta::PixelDelta(pos) => pos.y as _,
546 };
547 if mapping.consume == InputConsume::Hit {
548 consume = true;
549 }
550 }
551 }
552 _ => {}
553 }
554 }
555 if consume {
556 break;
557 }
558 }
559 }
560 }
561 WindowEvent::MouseInput { state, button, .. } => {
562 for (_, mapping) in self.mappings_stack.iter().rev() {
563 if let Some(mapping) = mapping.read() {
564 let mut consume = mapping.consume == InputConsume::All;
565 for (id, data) in &mapping.actions {
566 if let VirtualAction::MouseButton(btn) = id {
567 if button == btn {
568 if let Some(mut data) = data.write() {
569 *data = data.change(*state == ElementState::Pressed);
570 if mapping.consume == InputConsume::Hit {
571 consume = true;
572 }
573 }
574 }
575 }
576 }
577 for (id, data) in &mapping.axes {
578 if let VirtualAxis::MouseButton(btn) = id {
579 if button == btn {
580 if let Some(mut data) = data.write() {
581 data.0 = if *state == ElementState::Pressed {
582 1.0
583 } else {
584 0.0
585 };
586 if mapping.consume == InputConsume::Hit {
587 consume = true;
588 }
589 }
590 }
591 }
592 }
593 if consume {
594 break;
595 }
596 }
597 }
598 }
599 WindowEvent::AxisMotion { axis, value, .. } => {
600 for (_, mapping) in self.mappings_stack.iter().rev() {
601 if let Some(mapping) = mapping.read() {
602 let mut consume = mapping.consume == InputConsume::All;
603 for (id, data) in &mapping.actions {
604 if let VirtualAction::Axis(index) = id {
605 if axis == index {
606 if let Some(mut data) = data.write() {
607 *data = data.change(value.abs() > 0.5);
608 if mapping.consume == InputConsume::Hit {
609 consume = true;
610 }
611 }
612 }
613 }
614 }
615 for (id, data) in &mapping.axes {
616 if let VirtualAxis::Axis(index) = id {
617 if axis == index {
618 if let Some(mut data) = data.write() {
619 data.0 = *value as _;
620 if mapping.consume == InputConsume::Hit {
621 consume = true;
622 }
623 }
624 }
625 }
626 }
627 if consume {
628 break;
629 }
630 }
631 }
632 }
633 _ => {}
634 }
635 }
636}
637
#[cfg(test)]
mod tests {
    use crate::{InputContext, InputMapping};

    /// Pushing keeps the stack sorted by ascending layer, and mappings that share a
    /// layer stay in the order they were pushed.
    #[test]
    fn test_stack() {
        let pushed: [(&'static str, isize); 11] = [
            ("a", 0),
            ("b", 0),
            ("c", 0),
            ("d", -1),
            ("e", 1),
            ("f", -1),
            ("g", 1),
            ("h", -2),
            ("i", -2),
            ("j", 2),
            ("k", 2),
        ];
        let mut context = InputContext::default();
        for (name, layer) in pushed {
            context.push_mapping(InputMapping::default().name(name).layer(layer));
        }

        let expected: Vec<(String, isize)> = [
            ("h", -2),
            ("i", -2),
            ("d", -1),
            ("f", -1),
            ("a", 0),
            ("b", 0),
            ("c", 0),
            ("e", 1),
            ("g", 1),
            ("j", 2),
            ("k", 2),
        ]
        .iter()
        .map(|&(name, layer)| (name.to_owned(), layer))
        .collect();

        let provided: Vec<(String, isize)> = context
            .stack()
            .map(|mapping| {
                let guard = mapping.read().unwrap();
                (guard.name.to_string(), guard.layer)
            })
            .collect();

        assert_eq!(provided, expected);
    }
}