use bevy::{
ecs::{
entity::{Entity, MapEntities},
event::Event,
},
math::{Vec2, Vec3},
prelude::{EntityMapper, EventWriter, Query, Res},
utils::{HashMap, HashSet},
};
use serde::{Deserialize, Serialize};
use crate::{action_state::ActionKindData, prelude::ActionState, Actionlike};
/// A change in the state of a single action, as produced by
/// [`SummarizedActionState::entity_diffs`].
///
/// Each variant carries the action that changed together with its new value (if the
/// action kind has one).
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub enum ActionDiff<A: Actionlike> {
    /// The button-like `action` transitioned to pressed.
    Pressed { action: A },
    /// The button-like `action` transitioned to released.
    Released { action: A },
    /// The single-axis `action` now has `value`.
    AxisChanged { action: A, value: f32 },
    /// The dual-axis `action` now has `axis_pair`.
    DualAxisChanged { action: A, axis_pair: Vec2 },
    /// The triple-axis `action` now has `axis_triple`.
    TripleAxisChanged { action: A, axis_triple: Vec3 },
}
/// A batch of [`ActionDiff`]s belonging to one owner, emitted as a Bevy event by
/// [`SummarizedActionState::send_diffs`].
#[derive(Clone, Debug, Deserialize, Event, PartialEq, Serialize)]
pub struct ActionDiffEvent<A: Actionlike> {
    /// The entity whose `ActionState` component changed, or `None` when the diffs
    /// describe the global `ActionState` resource.
    pub owner: Option<Entity>,
    /// Every change recorded for `owner`.
    pub action_diffs: Vec<ActionDiff<A>>,
}
/// Remaps the event's `owner` through the supplied mapper.
///
/// Events for the global resource (`owner == None`) have nothing to remap and are left as-is.
impl<A: Actionlike> MapEntities for ActionDiffEvent<A> {
    fn map_entities<M: EntityMapper>(&mut self, entity_mapper: &mut M) {
        if let Some(owner) = self.owner.as_mut() {
            *owner = entity_mapper.map_entity(*owner);
        }
    }
}
/// A snapshot of the current value of every action, split by action kind and keyed by
/// owning entity.
///
/// The global `ActionState` resource, when summarized, is stored under
/// [`Entity::PLACEHOLDER`].
#[derive(Clone, Debug, PartialEq)]
pub struct SummarizedActionState<A: Actionlike> {
    /// Pressed state of each button-like action, per entity.
    button_state_map: HashMap<Entity, HashMap<A, bool>>,
    /// Value of each single-axis action, per entity.
    axis_state_map: HashMap<Entity, HashMap<A, f32>>,
    /// Value of each dual-axis action, per entity.
    dual_axis_state_map: HashMap<Entity, HashMap<A, Vec2>>,
    /// Value of each triple-axis action, per entity.
    triple_axis_state_map: HashMap<Entity, HashMap<A, Vec3>>,
}
impl<A: Actionlike> SummarizedActionState<A> {
pub fn all_entities(&self) -> HashSet<Entity> {
let mut entities = HashSet::new();
let button_entities = self.button_state_map.keys();
let axis_entities = self.axis_state_map.keys();
let dual_axis_entities = self.dual_axis_state_map.keys();
let triple_axis_entities = self.triple_axis_state_map.keys();
entities.extend(button_entities);
entities.extend(axis_entities);
entities.extend(dual_axis_entities);
entities.extend(triple_axis_entities);
entities
}
pub fn summarize(
global_action_state: Option<Res<ActionState<A>>>,
action_state_query: Query<(Entity, &ActionState<A>)>,
) -> Self {
let mut button_state_map = HashMap::default();
let mut axis_state_map = HashMap::default();
let mut dual_axis_state_map = HashMap::default();
let mut triple_axis_state_map = HashMap::default();
if let Some(global_action_state) = global_action_state {
let mut per_entity_button_state = HashMap::default();
let mut per_entity_axis_state = HashMap::default();
let mut per_entity_dual_axis_state = HashMap::default();
let mut per_entity_triple_axis_state = HashMap::default();
for (action, action_data) in global_action_state.all_action_data() {
match &action_data.kind_data {
ActionKindData::Button(button_data) => {
per_entity_button_state.insert(action.clone(), button_data.pressed());
}
ActionKindData::Axis(axis_data) => {
per_entity_axis_state.insert(action.clone(), axis_data.value);
}
ActionKindData::DualAxis(dual_axis_data) => {
per_entity_dual_axis_state.insert(action.clone(), dual_axis_data.pair);
}
ActionKindData::TripleAxis(triple_axis_data) => {
per_entity_triple_axis_state
.insert(action.clone(), triple_axis_data.triple);
}
}
}
button_state_map.insert(Entity::PLACEHOLDER, per_entity_button_state);
axis_state_map.insert(Entity::PLACEHOLDER, per_entity_axis_state);
dual_axis_state_map.insert(Entity::PLACEHOLDER, per_entity_dual_axis_state);
triple_axis_state_map.insert(Entity::PLACEHOLDER, per_entity_triple_axis_state);
}
for (entity, action_state) in action_state_query.iter() {
let mut per_entity_button_state = HashMap::default();
let mut per_entity_axis_state = HashMap::default();
let mut per_entity_dual_axis_state = HashMap::default();
let mut per_entity_triple_axis_state = HashMap::default();
for (action, action_data) in action_state.all_action_data() {
match &action_data.kind_data {
ActionKindData::Button(button_data) => {
per_entity_button_state.insert(action.clone(), button_data.pressed());
}
ActionKindData::Axis(axis_data) => {
per_entity_axis_state.insert(action.clone(), axis_data.value);
}
ActionKindData::DualAxis(dual_axis_data) => {
per_entity_dual_axis_state.insert(action.clone(), dual_axis_data.pair);
}
ActionKindData::TripleAxis(triple_axis_data) => {
per_entity_triple_axis_state
.insert(action.clone(), triple_axis_data.triple);
}
}
}
button_state_map.insert(entity, per_entity_button_state);
axis_state_map.insert(entity, per_entity_axis_state);
dual_axis_state_map.insert(entity, per_entity_dual_axis_state);
triple_axis_state_map.insert(entity, per_entity_triple_axis_state);
}
Self {
button_state_map,
axis_state_map,
dual_axis_state_map,
triple_axis_state_map,
}
}
pub fn button_diff(
action: A,
previous_button: Option<bool>,
current_button: Option<bool>,
) -> Option<ActionDiff<A>> {
let previous_button = previous_button.unwrap_or_default();
let current_button = current_button?;
(previous_button != current_button).then(|| {
if current_button {
ActionDiff::Pressed { action }
} else {
ActionDiff::Released { action }
}
})
}
pub fn axis_diff(
action: A,
previous_axis: Option<f32>,
current_axis: Option<f32>,
) -> Option<ActionDiff<A>> {
let previous_axis = previous_axis.unwrap_or_default();
let current_axis = current_axis?;
(previous_axis != current_axis).then(|| ActionDiff::AxisChanged {
action,
value: current_axis,
})
}
pub fn dual_axis_diff(
action: A,
previous_dual_axis: Option<Vec2>,
current_dual_axis: Option<Vec2>,
) -> Option<ActionDiff<A>> {
let previous_dual_axis = previous_dual_axis.unwrap_or_default();
let current_dual_axis = current_dual_axis?;
(previous_dual_axis != current_dual_axis).then(|| ActionDiff::DualAxisChanged {
action,
axis_pair: current_dual_axis,
})
}
pub fn triple_axis_diff(
action: A,
previous_triple_axis: Option<Vec3>,
current_triple_axis: Option<Vec3>,
) -> Option<ActionDiff<A>> {
let previous_triple_axis = previous_triple_axis.unwrap_or_default();
let current_triple_axis = current_triple_axis?;
(previous_triple_axis != current_triple_axis).then(|| ActionDiff::TripleAxisChanged {
action,
axis_triple: current_triple_axis,
})
}
pub fn entity_diffs(&self, entity: &Entity, previous: &Self) -> Vec<ActionDiff<A>> {
let mut action_diffs = Vec::new();
if let Some(current_button_state) = self.button_state_map.get(entity) {
let previous_button_state = previous.button_state_map.get(entity);
for (action, current_button) in current_button_state {
let previous_button = previous_button_state
.and_then(|previous_button_state| previous_button_state.get(action))
.copied();
if let Some(diff) =
Self::button_diff(action.clone(), previous_button, Some(*current_button))
{
action_diffs.push(diff);
}
}
}
if let Some(current_axis_state) = self.axis_state_map.get(entity) {
let previous_axis_state = previous.axis_state_map.get(entity);
for (action, current_axis) in current_axis_state {
let previous_axis = previous_axis_state
.and_then(|previous_axis_state| previous_axis_state.get(action))
.copied();
if let Some(diff) =
Self::axis_diff(action.clone(), previous_axis, Some(*current_axis))
{
action_diffs.push(diff);
}
}
}
if let Some(current_dual_axis_state) = self.dual_axis_state_map.get(entity) {
let previous_dual_axis_state = previous.dual_axis_state_map.get(entity);
for (action, current_dual_axis) in current_dual_axis_state {
let previous_dual_axis = previous_dual_axis_state
.and_then(|previous_dual_axis_state| previous_dual_axis_state.get(action))
.copied();
if let Some(diff) = Self::dual_axis_diff(
action.clone(),
previous_dual_axis,
Some(*current_dual_axis),
) {
action_diffs.push(diff);
}
}
}
if let Some(current_triple_axis_state) = self.triple_axis_state_map.get(entity) {
let previous_triple_axis_state = previous.triple_axis_state_map.get(entity);
for (action, current_triple_axis) in current_triple_axis_state {
let previous_triple_axis = previous_triple_axis_state
.and_then(|previous_triple_axis_state| previous_triple_axis_state.get(action))
.copied();
if let Some(diff) = Self::triple_axis_diff(
action.clone(),
previous_triple_axis,
Some(*current_triple_axis),
) {
action_diffs.push(diff);
}
}
}
action_diffs
}
pub fn send_diffs(&self, previous: &Self, writer: &mut EventWriter<ActionDiffEvent<A>>) {
for entity in self.all_entities() {
let owner = (entity != Entity::PLACEHOLDER).then_some(entity);
let action_diffs = self.entity_diffs(&entity, previous);
if !action_diffs.is_empty() {
writer.send(ActionDiffEvent {
owner,
action_diffs,
});
}
}
}
}
impl<A: Actionlike> Default for SummarizedActionState<A> {
fn default() -> Self {
Self {
button_state_map: Default::default(),
axis_state_map: Default::default(),
dual_axis_state_map: Default::default(),
triple_axis_state_map: Default::default(),
}
}
}
#[cfg(test)]
mod tests {
    use crate as leafwing_input_manager;

    use super::*;
    use bevy::{ecs::system::SystemState, prelude::*};

    #[derive(Actionlike, Debug, Clone, Copy, PartialEq, Eq, Hash, Reflect)]
    enum TestAction {
        Button,
        #[actionlike(Axis)]
        Axis,
        #[actionlike(DualAxis)]
        DualAxis,
        #[actionlike(TripleAxis)]
        TripleAxis,
    }

    /// Builds an `ActionState` in which every action holds a non-default value.
    fn test_action_state() -> ActionState<TestAction> {
        let mut action_state = ActionState::default();
        action_state.press(&TestAction::Button);
        action_state.set_value(&TestAction::Axis, 0.3);
        action_state.set_axis_pair(&TestAction::DualAxis, Vec2::new(0.5, 0.7));
        action_state.set_axis_triple(&TestAction::TripleAxis, Vec3::new(0.5, 0.7, 0.9));
        action_state
    }

    /// The summary expected from `test_action_state`, stored under `entity`.
    fn expected_summary(entity: Entity) -> SummarizedActionState<TestAction> {
        let mut button_state_map = HashMap::default();
        let mut axis_state_map = HashMap::default();
        let mut dual_axis_state_map = HashMap::default();
        let mut triple_axis_state_map = HashMap::default();

        let mut global_button_state = HashMap::default();
        global_button_state.insert(TestAction::Button, true);
        button_state_map.insert(entity, global_button_state);

        let mut global_axis_state = HashMap::default();
        global_axis_state.insert(TestAction::Axis, 0.3);
        axis_state_map.insert(entity, global_axis_state);

        let mut global_dual_axis_state = HashMap::default();
        global_dual_axis_state.insert(TestAction::DualAxis, Vec2::new(0.5, 0.7));
        dual_axis_state_map.insert(entity, global_dual_axis_state);

        let mut global_triple_axis_state = HashMap::default();
        global_triple_axis_state.insert(TestAction::TripleAxis, Vec3::new(0.5, 0.7, 0.9));
        triple_axis_state_map.insert(entity, global_triple_axis_state);

        SummarizedActionState {
            button_state_map,
            axis_state_map,
            dual_axis_state_map,
            triple_axis_state_map,
        }
    }

    #[test]
    fn summarize_from_resource() {
        let mut world = World::new();
        world.insert_resource(test_action_state());

        let mut system_state: SystemState<(
            Option<Res<ActionState<TestAction>>>,
            Query<(Entity, &ActionState<TestAction>)>,
        )> = SystemState::new(&mut world);
        let (global_action_state, action_state_query) = system_state.get(&world);
        let summarized = SummarizedActionState::summarize(global_action_state, action_state_query);

        assert_eq!(summarized, expected_summary(Entity::PLACEHOLDER));
    }

    #[test]
    fn summarize_from_component() {
        let mut world = World::new();
        let entity = world.spawn(test_action_state()).id();

        let mut system_state: SystemState<(
            Option<Res<ActionState<TestAction>>>,
            Query<(Entity, &ActionState<TestAction>)>,
        )> = SystemState::new(&mut world);
        let (global_action_state, action_state_query) = system_state.get(&world);
        let summarized = SummarizedActionState::summarize(global_action_state, action_state_query);

        assert_eq!(summarized, expected_summary(entity));
    }

    #[test]
    fn diffs_are_sent() {
        let mut world = World::new();
        world.init_resource::<Events<ActionDiffEvent<TestAction>>>();
        let entity = world.spawn(test_action_state()).id();

        let mut system_state: SystemState<(
            Option<Res<ActionState<TestAction>>>,
            Query<(Entity, &ActionState<TestAction>)>,
            EventWriter<ActionDiffEvent<TestAction>>,
        )> = SystemState::new(&mut world);
        let (global_action_state, action_state_query, mut action_diff_writer) =
            system_state.get_mut(&mut world);
        let summarized = SummarizedActionState::summarize(global_action_state, action_state_query);
        let previous = SummarizedActionState::default();
        summarized.send_diffs(&previous, &mut action_diff_writer);

        let mut system_state: SystemState<EventReader<ActionDiffEvent<TestAction>>> =
            SystemState::new(&mut world);
        let mut event_reader = system_state.get_mut(&mut world);
        let action_diff_events = event_reader.read().collect::<Vec<_>>();

        assert_eq!(action_diff_events.len(), 1);
        let action_diff_event = action_diff_events[0];
        assert_eq!(action_diff_event.owner, Some(entity));
        assert_eq!(action_diff_event.action_diffs.len(), 4);

        let expected_diffs = [
            ActionDiff::Pressed {
                action: TestAction::Button,
            },
            ActionDiff::AxisChanged {
                action: TestAction::Axis,
                value: 0.3,
            },
            ActionDiff::DualAxisChanged {
                action: TestAction::DualAxis,
                axis_pair: Vec2::new(0.5, 0.7),
            },
            ActionDiff::TripleAxisChanged {
                action: TestAction::TripleAxis,
                axis_triple: Vec3::new(0.5, 0.7, 0.9),
            },
        ];
        for expected in &expected_diffs {
            assert!(
                action_diff_event.action_diffs.contains(expected),
                "missing {expected:?}"
            );
        }
    }

    #[test]
    fn unchanged_state_sends_no_diffs() {
        let mut world = World::new();
        world.init_resource::<Events<ActionDiffEvent<TestAction>>>();
        world.spawn(test_action_state());

        let mut system_state: SystemState<(
            Option<Res<ActionState<TestAction>>>,
            Query<(Entity, &ActionState<TestAction>)>,
            EventWriter<ActionDiffEvent<TestAction>>,
        )> = SystemState::new(&mut world);
        let (global_action_state, action_state_query, mut action_diff_writer) =
            system_state.get_mut(&mut world);
        let summarized = SummarizedActionState::summarize(global_action_state, action_state_query);
        // Diffing a summary against itself must produce nothing to send.
        summarized.send_diffs(&summarized, &mut action_diff_writer);

        let mut system_state: SystemState<EventReader<ActionDiffEvent<TestAction>>> =
            SystemState::new(&mut world);
        let mut event_reader = system_state.get_mut(&mut world);
        assert_eq!(event_reader.read().count(), 0);
    }

    #[test]
    fn diff_helpers_report_only_changes() {
        type Summary = SummarizedActionState<TestAction>;

        assert_eq!(
            Summary::button_diff(TestAction::Button, None, Some(true)),
            Some(ActionDiff::Pressed {
                action: TestAction::Button
            })
        );
        assert_eq!(
            Summary::button_diff(TestAction::Button, Some(true), Some(false)),
            Some(ActionDiff::Released {
                action: TestAction::Button
            })
        );
        assert_eq!(
            Summary::button_diff(TestAction::Button, Some(true), Some(true)),
            None
        );
        // A missing previous value counts as the default, so an unseen released button
        // produces no diff, and a missing current value never produces one.
        assert_eq!(Summary::button_diff(TestAction::Button, None, Some(false)), None);
        assert_eq!(Summary::button_diff(TestAction::Button, Some(true), None), None);

        assert_eq!(
            Summary::axis_diff(TestAction::Axis, None, Some(0.3)),
            Some(ActionDiff::AxisChanged {
                action: TestAction::Axis,
                value: 0.3
            })
        );
        assert_eq!(Summary::axis_diff(TestAction::Axis, Some(0.3), Some(0.3)), None);
        assert_eq!(Summary::axis_diff(TestAction::Axis, None, Some(0.0)), None);

        assert_eq!(
            Summary::dual_axis_diff(TestAction::DualAxis, None, Some(Vec2::ZERO)),
            None
        );
        assert_eq!(
            Summary::triple_axis_diff(
                TestAction::TripleAxis,
                Some(Vec3::ZERO),
                Some(Vec3::new(0.5, 0.7, 0.9))
            ),
            Some(ActionDiff::TripleAxisChanged {
                action: TestAction::TripleAxis,
                axis_triple: Vec3::new(0.5, 0.7, 0.9)
            })
        );
    }
}