pub mod bert;
pub mod byte_level;
pub mod precompiled;
pub mod prepend;
pub mod replace;
pub mod strip;
pub mod unicode;
pub mod utils;
pub use crate::normalizers::bert::BertNormalizer;
pub use crate::normalizers::byte_level::ByteLevel;
pub use crate::normalizers::precompiled::Precompiled;
pub use crate::normalizers::prepend::Prepend;
pub use crate::normalizers::replace::Replace;
pub use crate::normalizers::strip::{Strip, StripAccents};
pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD};
pub use crate::normalizers::utils::{Lowercase, Sequence};
use serde::{Deserialize, Deserializer, Serialize};
use crate::{NormalizedString, Normalizer};
/// Wrapper enum holding any one of the normalizers declared in this module, so callers can
/// store, (de)serialize and run a normalizer without knowing its concrete type.
///
/// Serialization is `untagged`: each variant serializes exactly as its inner value does.
/// Deserialization is implemented by hand in this module. It accepts both the form carrying a
/// `"type"` key and a legacy form without one.
#[derive(Clone, Debug, Serialize)]
#[serde(untagged)]
pub enum NormalizerWrapper {
/// Deserialized from the `"type": "Bert"` tag.
BertNormalizer(BertNormalizer),
/// Deserialized from the `"type": "Strip"` tag.
StripNormalizer(Strip),
StripAccents(StripAccents),
NFC(NFC),
NFD(NFD),
NFKC(NFKC),
NFKD(NFKD),
Sequence(Sequence),
Lowercase(Lowercase),
Nmt(Nmt),
Precompiled(Precompiled),
Replace(Replace),
Prepend(Prepend),
ByteLevel(ByteLevel),
}
impl<'de> Deserialize<'de> for NormalizerWrapper {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
pub struct Tagged {
#[serde(rename = "type")]
variant: EnumType,
#[serde(flatten)]
rest: serde_json::Value,
}
#[derive(Serialize, Deserialize)]
pub enum EnumType {
Bert,
Strip,
StripAccents,
NFC,
NFD,
NFKC,
NFKD,
Sequence,
Lowercase,
Nmt,
Precompiled,
Replace,
Prepend,
ByteLevel,
}
#[derive(Deserialize)]
#[serde(untagged)]
pub enum NormalizerHelper {
Tagged(Tagged),
Legacy(serde_json::Value),
}
#[derive(Deserialize)]
#[serde(untagged)]
pub enum NormalizerUntagged {
BertNormalizer(BertNormalizer),
StripNormalizer(Strip),
StripAccents(StripAccents),
NFC(NFC),
NFD(NFD),
NFKC(NFKC),
NFKD(NFKD),
Sequence(Sequence),
Lowercase(Lowercase),
Nmt(Nmt),
Precompiled(Precompiled),
Replace(Replace),
Prepend(Prepend),
ByteLevel(ByteLevel),
}
let helper = NormalizerHelper::deserialize(deserializer)?;
Ok(match helper {
NormalizerHelper::Tagged(model) => {
let mut values: serde_json::Map<String, serde_json::Value> =
serde_json::from_value(model.rest).expect("Parsed values");
values.insert(
"type".to_string(),
serde_json::to_value(&model.variant).expect("Reinsert"),
);
let values = serde_json::Value::Object(values);
match model.variant {
EnumType::Bert => NormalizerWrapper::BertNormalizer(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Strip => NormalizerWrapper::StripNormalizer(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::StripAccents => NormalizerWrapper::StripAccents(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::NFC => NormalizerWrapper::NFC(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::NFD => NormalizerWrapper::NFD(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::NFKC => NormalizerWrapper::NFKC(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::NFKD => NormalizerWrapper::NFKD(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Sequence => NormalizerWrapper::Sequence(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Lowercase => NormalizerWrapper::Lowercase(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Nmt => NormalizerWrapper::Nmt(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Precompiled => NormalizerWrapper::Precompiled(
serde_json::from_str(
&serde_json::to_string(&values).expect("Can reserialize precompiled"),
)
.expect("Precompiled"),
),
EnumType::Replace => NormalizerWrapper::Replace(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Prepend => NormalizerWrapper::Prepend(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::ByteLevel => NormalizerWrapper::ByteLevel(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
}
}
NormalizerHelper::Legacy(value) => {
let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
match untagged {
NormalizerUntagged::BertNormalizer(bpe) => {
NormalizerWrapper::BertNormalizer(bpe)
}
NormalizerUntagged::StripNormalizer(bpe) => {
NormalizerWrapper::StripNormalizer(bpe)
}
NormalizerUntagged::StripAccents(bpe) => NormalizerWrapper::StripAccents(bpe),
NormalizerUntagged::NFC(bpe) => NormalizerWrapper::NFC(bpe),
NormalizerUntagged::NFD(bpe) => NormalizerWrapper::NFD(bpe),
NormalizerUntagged::NFKC(bpe) => NormalizerWrapper::NFKC(bpe),
NormalizerUntagged::NFKD(bpe) => NormalizerWrapper::NFKD(bpe),
NormalizerUntagged::Sequence(bpe) => NormalizerWrapper::Sequence(bpe),
NormalizerUntagged::Lowercase(bpe) => NormalizerWrapper::Lowercase(bpe),
NormalizerUntagged::Nmt(bpe) => NormalizerWrapper::Nmt(bpe),
NormalizerUntagged::Precompiled(bpe) => NormalizerWrapper::Precompiled(bpe),
NormalizerUntagged::Replace(bpe) => NormalizerWrapper::Replace(bpe),
NormalizerUntagged::Prepend(bpe) => NormalizerWrapper::Prepend(bpe),
NormalizerUntagged::ByteLevel(bpe) => NormalizerWrapper::ByteLevel(bpe),
}
}
})
}
}
impl Normalizer for NormalizerWrapper {
fn normalize(&self, normalized: &mut NormalizedString) -> crate::Result<()> {
match self {
Self::BertNormalizer(bn) => bn.normalize(normalized),
Self::StripNormalizer(sn) => sn.normalize(normalized),
Self::StripAccents(sn) => sn.normalize(normalized),
Self::NFC(nfc) => nfc.normalize(normalized),
Self::NFD(nfd) => nfd.normalize(normalized),
Self::NFKC(nfkc) => nfkc.normalize(normalized),
Self::NFKD(nfkd) => nfkd.normalize(normalized),
Self::Sequence(sequence) => sequence.normalize(normalized),
Self::Lowercase(lc) => lc.normalize(normalized),
Self::Nmt(lc) => lc.normalize(normalized),
Self::Precompiled(lc) => lc.normalize(normalized),
Self::Replace(lc) => lc.normalize(normalized),
Self::Prepend(lc) => lc.normalize(normalized),
Self::ByteLevel(lc) => lc.normalize(normalized),
}
}
}
// Conversions from each concrete normalizer into `NormalizerWrapper`, listed in the same order
// as the enum's variants. `impl_enum_from!` is defined elsewhere in the crate; presumably it
// expands to a `From<Type> for NormalizerWrapper` impl targeting the named variant.
impl_enum_from!(BertNormalizer, NormalizerWrapper, BertNormalizer);
impl_enum_from!(Strip, NormalizerWrapper, StripNormalizer);
impl_enum_from!(StripAccents, NormalizerWrapper, StripAccents);
impl_enum_from!(NFC, NormalizerWrapper, NFC);
impl_enum_from!(NFD, NormalizerWrapper, NFD);
impl_enum_from!(NFKC, NormalizerWrapper, NFKC);
impl_enum_from!(NFKD, NormalizerWrapper, NFKD);
impl_enum_from!(Sequence, NormalizerWrapper, Sequence);
impl_enum_from!(Lowercase, NormalizerWrapper, Lowercase);
impl_enum_from!(Nmt, NormalizerWrapper, Nmt);
impl_enum_from!(Precompiled, NormalizerWrapper, Precompiled);
impl_enum_from!(Replace, NormalizerWrapper, Replace);
impl_enum_from!(Prepend, NormalizerWrapper, Prepend);
impl_enum_from!(ByteLevel, NormalizerWrapper, ByteLevel);
#[cfg(test)]
mod tests {
    use super::*;

    /// Error reported when tag-less input matches none of the legacy normalizer shapes.
    const NO_MATCHING_VARIANT: &str =
        "data did not match any variant of untagged enum NormalizerUntagged";

    #[test]
    fn normalizer_deserialization_no_type() {
        // Tag-less input with the `Strip` shape.
        let json = r#"{"strip_left":false, "strip_right":true}"#;
        let reconstructed = serde_json::from_str::<NormalizerWrapper>(json);
        assert!(matches!(
            reconstructed.unwrap(),
            NormalizerWrapper::StripNormalizer(_)
        ));

        // Tag-less input whose keys fit no normalizer.
        let json = r#"{"trim_offsets":true, "add_prefix_space":true}"#;
        let err = serde_json::from_str::<NormalizerWrapper>(json).unwrap_err();
        assert_eq!(err.to_string(), NO_MATCHING_VARIANT);

        // Tag-less input with the `Prepend` shape.
        let json = r#"{"prepend":"a"}"#;
        let reconstructed = serde_json::from_str::<NormalizerWrapper>(json);
        assert!(matches!(
            reconstructed.unwrap(),
            NormalizerWrapper::Prepend(_)
        ));
    }

    #[test]
    fn normalizer_serialization() {
        // An empty tagged sequence is valid.
        let json = r#"{"type":"Sequence","normalizers":[]}"#;
        assert!(serde_json::from_str::<NormalizerWrapper>(json).is_ok());

        // A nested normalizer matching no shape surfaces the legacy error.
        let json = r#"{"type":"Sequence","normalizers":[{}]}"#;
        let err = serde_json::from_str::<NormalizerWrapper>(json).unwrap_err();
        assert_eq!(format!("{err}"), NO_MATCHING_VARIANT);

        // Tag-less input whose keys fit no normalizer.
        let json = r#"{"replacement":"▁","prepend_scheme":"always"}"#;
        let err = serde_json::from_str::<NormalizerWrapper>(json).unwrap_err();
        assert_eq!(format!("{err}"), NO_MATCHING_VARIANT);

        // A tagged payload missing a required field reports that field.
        let json = r#"{"type":"Sequence","prepend_scheme":"always"}"#;
        let err = serde_json::from_str::<NormalizerWrapper>(json).unwrap_err();
        assert_eq!(format!("{err}"), "missing field `normalizers`");
    }
}