use std::fmt;
use beef::Cow;
use serde::de::{self, Deserializer, Unexpected, Visitor};
use serde::ser::Serializer;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use crate::error::{ErrorCode, INVALID_PARAMS_MSG};
use crate::{ErrorObject, ErrorObjectOwned};
/// Marker for the JSON-RPC protocol version.
///
/// Carries no data: it serializes as the string `"2.0"` and only deserializes
/// from exactly that string (see the `Serialize`/`Deserialize` impls in this module).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct TwoPointZero;
/// Serde visitor that accepts only the exact string `"2.0"`.
struct TwoPointZeroVisitor;

impl<'de> Visitor<'de> for TwoPointZeroVisitor {
    type Value = TwoPointZero;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str(r#"a string "2.0""#)
    }

    /// Accepts `"2.0"`; any other string is rejected as an invalid value.
    fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        if s != "2.0" {
            return Err(E::invalid_value(Unexpected::Str(s), &self));
        }
        Ok(TwoPointZero)
    }
}
impl<'de> Deserialize<'de> for TwoPointZero {
    /// Deserializes from a string, succeeding only when it equals `"2.0"`.
    fn deserialize<D>(deserializer: D) -> Result<TwoPointZero, D::Error>
    where
        D: Deserializer<'de>,
    {
        let visitor = TwoPointZeroVisitor;
        deserializer.deserialize_str(visitor)
    }
}
impl Serialize for TwoPointZero {
    /// Always emits the literal string `"2.0"`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        const VERSION: &str = "2.0";
        serializer.serialize_str(VERSION)
    }
}
/// Raw parameters of a JSON-RPC call.
///
/// Stores the whitespace-trimmed JSON text of the params (or `None` when no
/// params were supplied) and only deserializes it when asked to.
#[derive(Clone, Debug)]
pub struct Params<'a>(Option<Cow<'a, str>>);

impl<'a> Params<'a> {
    /// Wraps the raw params text, trimming leading and trailing whitespace.
    pub fn new(raw: Option<&'a str>) -> Self {
        let trimmed: Option<Cow<'a, str>> = raw.map(|text| text.trim().into());
        Self(trimmed)
    }

    /// Returns `true` if the params text begins with `{` (i.e. looks like a JSON object).
    ///
    /// Returns `false` when there are no params.
    pub fn is_object(&self) -> bool {
        match self.0 {
            Some(ref json) => json.starts_with('{'),
            None => false,
        }
    }

    /// Returns a cursor that walks the params as a JSON array.
    ///
    /// Missing params and the exact text `[]` both produce an empty cursor.
    pub fn sequence(&self) -> ParamsSequence {
        let json: &str = match self.0 {
            Some(ref cow) if &**cow != "[]" => cow,
            _ => "",
        };
        ParamsSequence(json)
    }

    /// Deserializes the whole params text as `T`.
    ///
    /// Missing params are parsed as JSON `null`.
    ///
    /// # Errors
    /// Returns an "invalid params" error object carrying the serde_json message
    /// if the text cannot be deserialized into `T`.
    pub fn parse<T>(&'a self) -> Result<T, ErrorObjectOwned>
    where
        T: Deserialize<'a>,
    {
        let json: &'a str = match self.0 {
            Some(ref cow) => cow,
            None => "null",
        };
        serde_json::from_str(json).map_err(invalid_params)
    }

    /// Deserializes params that must be a JSON array of exactly one element.
    ///
    /// # Errors
    /// Same as [`Params::parse`] for the type `[T; 1]`.
    pub fn one<T>(&'a self) -> Result<T, ErrorObjectOwned>
    where
        T: Deserialize<'a>,
    {
        let [only] = self.parse::<[T; 1]>()?;
        Ok(only)
    }

    /// Converts into `'static` params, copying the text if it was borrowed.
    pub fn into_owned(self) -> Params<'static> {
        let owned = self.0.map(|cow| Cow::owned(cow.into_owned()));
        Params(owned)
    }

    /// Length of the (trimmed) params text in bytes; `0` when there are no params.
    pub fn len_bytes(&self) -> usize {
        self.0.as_ref().map_or(0, |cow| cow.len())
    }

    /// The (trimmed) params text, or `None` when there are no params.
    pub fn as_str(&self) -> Option<&str> {
        self.0.as_deref()
    }
}
/// Cursor over the elements of a params JSON array.
///
/// Holds the not-yet-consumed remainder of the raw params text. Each successful
/// read advances past one element. An empty remainder means the cursor is exhausted.
#[derive(Debug, Copy, Clone)]
pub struct ParamsSequence<'a>(&'a str);

impl<'a> ParamsSequence<'a> {
    /// Reads the next array element as `T`.
    ///
    /// Returns `None` when nothing is left to read. That is the case when:
    /// - the remainder is empty;
    /// - the remainder starts with the closing `]`;
    /// - the remainder is an empty array that contains only whitespace, such as `[ ]`.
    ///
    /// Returns `Some(Err(_))` in two cases:
    /// - the remainder does not start with `[`, `]` or `,`;
    /// - the element fails to deserialize, in which case the cursor is also cleared.
    fn next_inner<T>(&mut self) -> Option<Result<T, ErrorObjectOwned>>
    where
        T: Deserialize<'a>,
    {
        let json = self.0;
        let rest = match json.as_bytes().first()? {
            // Closing bracket: every element has been consumed.
            b']' => {
                self.0 = "";
                return None;
            }
            b'[' => {
                // `[` followed only by whitespace and then `]` (e.g. `[ ]` or `[\n]`) is a
                // valid empty array. serde_json would reject the bare `]` with
                // "expected value", so treat it as end-of-sequence, just like `[]`.
                let rest = json[1..].trim_start();
                if rest.starts_with(']') {
                    self.0 = "";
                    return None;
                }
                rest
            }
            b',' => &json[1..],
            _ => {
                let errmsg = format!("Invalid params. Expected one of '[', ']' or ',' but found {json:?}");
                return Some(Err(invalid_params(errmsg)));
            }
        };

        // Decode exactly one value from the front of `rest`. The stream deserializer
        // reports how many bytes it consumed, so the cursor can skip past that value.
        let mut iter = serde_json::Deserializer::from_str(rest).into_iter::<T>();
        match iter.next()? {
            Ok(value) => {
                self.0 = rest[iter.byte_offset()..].trim_start();
                Some(Ok(value))
            }
            Err(e) => {
                self.0 = "";
                Some(Err(invalid_params(e)))
            }
        }
    }

    /// Parses the next element as `T`.
    ///
    /// # Errors
    /// Returns an "invalid params" error object in two cases:
    /// - the element is malformed;
    /// - the sequence has no more elements, with data `"No more params"`.
    // Not `Iterator::next`: the element type `T` is chosen per call, which the
    // single-`Item` `Iterator` trait cannot express.
    #[allow(clippy::should_implement_trait)]
    pub fn next<T>(&mut self) -> Result<T, ErrorObjectOwned>
    where
        T: Deserialize<'a>,
    {
        match self.next_inner() {
            Some(result) => result,
            None => Err(invalid_params("No more params")),
        }
    }

    /// Parses the next element as `Option<T>`.
    ///
    /// Returns `Ok(None)` both when the sequence is exhausted and when the element is
    /// JSON `null`. This lets trailing optional params be omitted.
    ///
    /// # Errors
    /// Returns an "invalid params" error object if the element is malformed.
    pub fn optional_next<T>(&mut self) -> Result<Option<T>, ErrorObjectOwned>
    where
        T: Deserialize<'a>,
    {
        match self.next_inner::<Option<T>>() {
            Some(result) => result,
            None => Ok(None),
        }
    }
}
/// Identifier of a subscription: either an unsigned integer or a string.
///
/// Because the enum is `untagged`, serde tries the variants in declaration order
/// (`Num` first, then `Str`). A JSON string therefore always becomes `Str`, even
/// if it contains digits, e.g. `"9"`.
// NOTE(review): `deny_unknown_fields` presumably has no effect on an untagged enum
// whose variants are a number and a string — confirm before relying on it.
#[derive(Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
#[serde(untagged)]
pub enum SubscriptionId<'a> {
/// Numeric subscription id.
Num(u64),
/// String subscription id; borrows from the input when deserialized where possible.
#[serde(borrow)]
Str(Cow<'a, str>),
}
impl<'a> From<SubscriptionId<'a>> for JsonValue {
fn from(sub_id: SubscriptionId) -> Self {
match sub_id {
SubscriptionId::Num(n) => n.into(),
SubscriptionId::Str(s) => s.into_owned().into(),
}
}
}
impl<'a> From<u64> for SubscriptionId<'a> {
fn from(sub_id: u64) -> Self {
Self::Num(sub_id)
}
}
impl<'a> From<String> for SubscriptionId<'a> {
fn from(sub_id: String) -> Self {
Self::Str(sub_id.into())
}
}
impl<'a> TryFrom<JsonValue> for SubscriptionId<'a> {
    type Error = ();

    /// Accepts a JSON string or a JSON number representable as `u64`.
    ///
    /// Everything else fails with `Err(())`: negative numbers, floats, null,
    /// booleans, arrays and objects.
    fn try_from(json: JsonValue) -> Result<SubscriptionId<'a>, ()> {
        match json {
            JsonValue::String(s) => Ok(SubscriptionId::Str(s.into())),
            JsonValue::Number(n) => n.as_u64().map(SubscriptionId::Num).ok_or(()),
            _ => Err(()),
        }
    }
}
impl<'a> SubscriptionId<'a> {
    /// Converts into a `'static` id, copying borrowed string data if needed.
    pub fn into_owned(self) -> SubscriptionId<'static> {
        match self {
            SubscriptionId::Str(text) => {
                let owned = text.into_owned();
                SubscriptionId::Str(Cow::owned(owned))
            }
            SubscriptionId::Num(num) => SubscriptionId::Num(num),
        }
    }
}
/// Errors describing why a request id cannot be used.
///
/// Each variant carries the id rendered as a string.
#[derive(Debug, thiserror::Error)]
pub enum InvalidRequestId {
/// The id does not belong to any pending call.
#[error("request ID={0} is not a pending call")]
NotPendingRequest(String),
/// The id is already in use by another pending call.
#[error("request ID={0} is already occupied by a pending call")]
Occupied(String),
/// The id could not be interpreted (e.g. `null` or a non-numeric string where a number is required).
#[error("request ID={0} is invalid")]
Invalid(String),
}
/// JSON-RPC request id: `null`, an unsigned integer, or a string.
///
/// Variant order matters for two reasons:
/// - The derived `Ord` sorts by declaration order, so `Null < Number(_) < Str(_)`.
/// - Because the enum is `untagged`, serde tries the variants in this order when deserializing.
///
/// A JSON number that does not fit `u64` (negative or fractional) matches no
/// variant, so deserializing it presumably fails.
// NOTE(review): `deny_unknown_fields` presumably has no effect on an untagged enum
// of scalar variants — confirm before relying on it.
#[derive(Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize, PartialOrd, Ord)]
#[serde(deny_unknown_fields)]
#[serde(untagged)]
pub enum Id<'a> {
/// JSON `null` id.
Null,
/// Numeric id.
Number(u64),
/// String id; borrows from the input when deserialized where possible.
#[serde(borrow)]
Str(Cow<'a, str>),
}
impl<'a> Id<'a> {
    /// Returns the numeric id, or `None` if this is not [`Id::Number`].
    pub fn as_number(&self) -> Option<&u64> {
        if let Self::Number(num) = self {
            Some(num)
        } else {
            None
        }
    }

    /// Returns the string id, or `None` if this is not [`Id::Str`].
    pub fn as_str(&self) -> Option<&str> {
        if let Self::Str(text) = self {
            Some(text.as_ref())
        } else {
            None
        }
    }

    /// Returns `Some(())` if this is [`Id::Null`], otherwise `None`.
    pub fn as_null(&self) -> Option<()> {
        if matches!(self, Self::Null) {
            Some(())
        } else {
            None
        }
    }

    /// Converts into a `'static` id, copying borrowed string data if needed.
    pub fn into_owned(self) -> Id<'static> {
        match self {
            Id::Str(text) => Id::Str(Cow::owned(text.into_owned())),
            Id::Number(num) => Id::Number(num),
            Id::Null => Id::Null,
        }
    }

    /// Interprets the id as a `u64`.
    ///
    /// Numeric ids are returned directly. String ids are parsed as `u64`.
    ///
    /// # Errors
    /// Returns [`InvalidRequestId::Invalid`] in two cases:
    /// - the id is `null` (data `"null"`);
    /// - the id is a string that does not parse as `u64` (data is that string).
    pub fn try_parse_inner_as_number(&self) -> Result<u64, InvalidRequestId> {
        let text = match self {
            Id::Number(num) => return Ok(*num),
            Id::Null => return Err(InvalidRequestId::Invalid("null".to_string())),
            Id::Str(s) => s,
        };
        text.parse::<u64>().map_err(|_| InvalidRequestId::Invalid(text.as_ref().to_owned()))
    }
}
impl<'a> std::fmt::Display for Id<'a> {
    /// Renders the id as plain text: `null`, the decimal number, or the raw string
    /// (strings are not quoted or escaped). Formatter width/fill flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Id::Null => f.write_str("null"),
            // Format straight into the formatter instead of allocating via `to_string()`.
            Id::Number(n) => write!(f, "{n}"),
            Id::Str(s) => f.write_str(s),
        }
    }
}
/// Builds an owned "invalid params" error object.
///
/// Uses the `InvalidParams` error code and `INVALID_PARAMS_MSG` as the message.
/// The string form of `e` is attached as the error data.
fn invalid_params(e: impl ToString) -> ErrorObjectOwned {
    let data = e.to_string();
    let code = ErrorCode::InvalidParams.code();
    ErrorObject::owned(code, INVALID_PARAMS_MSG, Some(data))
}
// Unit tests for the (de)serialization of ids, params, subscription ids and the
// protocol version marker.
#[cfg(test)]
mod test {
use super::{Cow, Id, JsonValue, Params, SubscriptionId, TwoPointZero};
use crate::response::SubscriptionPayload;
// Covers every `Id` variant, including borrowing for string ids and rejection of non-scalar JSON.
#[test]
fn id_deserialization() {
let s = r#""2""#;
let deserialized: Id = serde_json::from_str(s).unwrap();
// A quoted id stays a string (even if numeric) and borrows from the input.
match deserialized {
Id::Str(ref cow) => {
assert!(cow.is_borrowed());
assert_eq!(cow, "2");
}
_ => panic!("Expected Id::Str"),
}
let s = r#"2"#;
let deserialized: Id = serde_json::from_str(s).unwrap();
assert_eq!(deserialized, Id::Number(2));
let s = r#""2x""#;
let deserialized: Id = serde_json::from_str(s).unwrap();
assert_eq!(deserialized, Id::Str(Cow::const_str("2x")));
// Arrays are not valid ids.
let s = r#"[1337]"#;
assert!(serde_json::from_str::<Id>(s).is_err());
// Escaped strings still deserialize (they cannot borrow, so they are owned).
let s = r#"[null, 0, 2, "\"3"]"#;
let deserialized: Vec<Id> = serde_json::from_str(s).unwrap();
assert_eq!(deserialized, vec![Id::Null, Id::Number(0), Id::Number(2), Id::Str("\"3".into())]);
}
// Ids serialize back to plain JSON scalars (untagged).
#[test]
fn id_serialization() {
let d =
vec![Id::Null, Id::Number(0), Id::Number(2), Id::Number(3), Id::Str("\"3".into()), Id::Str("test".into())];
let serialized = serde_json::to_string(&d).unwrap();
assert_eq!(serialized, r#"[null,0,2,3,"\"3","test"]"#);
}
// Exercises `parse`, `one`, `sequence` and `len_bytes` on missing, array and object params.
#[test]
fn params_parse() {
// Missing params: the sequence is empty, but `parse` sees JSON `null`.
let none = Params::new(None);
assert!(none.sequence().next::<u64>().is_err());
assert!(none.parse::<Option<u64>>().is_ok());
assert_eq!(none.len_bytes(), 0);
let array_params = Params::new(Some("[1, 2, 3]"));
assert_eq!(array_params.len_bytes(), 9);
let arr: Result<[u64; 3], _> = array_params.parse();
assert!(arr.is_ok());
// Reading past the last element is an error for `next`.
let mut seq = array_params.sequence();
assert_eq!(seq.next::<u64>().unwrap(), 1);
assert_eq!(seq.next::<u64>().unwrap(), 2);
assert_eq!(seq.next::<u64>().unwrap(), 3);
assert!(seq.next::<u64>().is_err());
let array_one = Params::new(Some("[1]"));
assert_eq!(array_one.len_bytes(), 3);
let one: Result<u64, _> = array_one.one();
assert!(one.is_ok());
let object_params = Params::new(Some(r#"{"beef":99,"dinner":0}"#));
assert_eq!(object_params.len_bytes(), 22);
let obj: Result<JsonValue, _> = object_params.parse();
assert!(obj.is_ok());
}
// Empty array and empty object params both parse successfully.
#[test]
fn params_parse_empty_json() {
let array_params = Params::new(Some("[]"));
let arr: Result<Vec<u64>, _> = array_params.parse();
assert!(arr.is_ok());
let obj_params = Params::new(Some("{}"));
let obj: Result<JsonValue, _> = obj_params.parse();
assert!(obj.is_ok());
}
// Both the sequence cursor and `parse` can hand out `&str` borrowed from the params text.
#[test]
fn params_sequence_borrows() {
let params = Params::new(Some(r#"["foo", "bar"]"#));
let mut seq = params.sequence();
assert_eq!(seq.next::<&str>().unwrap(), "foo");
assert_eq!(seq.next::<&str>().unwrap(), "bar");
assert!(seq.next::<&str>().is_err());
let params: (&str, &str) = params.parse().unwrap();
assert_eq!(params, ("foo", "bar"));
}
// `"2.0"` round-trips through `TwoPointZero`.
#[test]
fn two_point_zero_serde_works() {
let initial_ser = r#""2.0""#;
let two_point_zero: TwoPointZero = serde_json::from_str(initial_ser).unwrap();
let serialized = serde_json::to_string(&two_point_zero).unwrap();
assert_eq!(serialized, initial_ser);
}
// Numeric and string subscription ids round-trip unchanged.
#[test]
fn subscription_id_serde_works() {
let test_vector = &[("42", SubscriptionId::Num(42)), (r#""one""#, SubscriptionId::Str("one".into()))];
for (initial_ser, expected) in test_vector {
let id: SubscriptionId = serde_json::from_str(initial_ser).unwrap();
assert_eq!(&id, expected);
let serialized = serde_json::to_string(&id).unwrap();
assert_eq!(&serialized, initial_ser);
}
}
#[test]
fn subscription_params_serialize_work() {
let ser = serde_json::to_string(&SubscriptionPayload { subscription: SubscriptionId::Num(12), result: "goal" })
.unwrap();
let exp = r#"{"subscription":12,"result":"goal"}"#;
assert_eq!(ser, exp);
}
// The payload result type must match: a string result cannot deserialize into `()`.
#[test]
fn subscription_params_deserialize_work() {
let ser = r#"{"subscription":"9","result":"offside"}"#;
assert!(
serde_json::from_str::<SubscriptionPayload<()>>(ser).is_err(),
"invalid type should not be deserializable"
);
let dsr: SubscriptionPayload<JsonValue> = serde_json::from_str(ser).unwrap();
assert_eq!(dsr.subscription, SubscriptionId::Str("9".into()));
assert_eq!(dsr.result, serde_json::json!("offside"));
}
// `optional_next` yields `None` for empty arrays, rejects objects, and does not
// confuse nested/stringified `[]` with an empty sequence.
#[test]
fn params_sequence_optional_ignore_empty() {
let params = Params::new(Some(r#"["foo", "bar"]"#));
let mut seq = params.sequence();
assert_eq!(seq.optional_next::<&str>().unwrap(), Some("foo"));
assert_eq!(seq.optional_next::<&str>().unwrap(), Some("bar"));
let params = Params::new(Some(r#"[]"#));
let mut seq = params.sequence();
assert!(seq.optional_next::<&str>().unwrap().is_none());
// Surrounding whitespace is trimmed by `Params::new`.
let params = Params::new(Some(r#" [] "#));
let mut seq = params.sequence();
assert!(seq.optional_next::<&str>().unwrap().is_none());
let params = Params::new(Some(r#"{}"#));
let mut seq = params.sequence();
assert!(seq.optional_next::<&str>().is_err(), "JSON object not supported by RpcSequence");
let params = Params::new(Some(r#"[12, "[]", [], {}]"#));
let mut seq = params.sequence();
assert_eq!(seq.optional_next::<u64>().unwrap(), Some(12));
assert_eq!(seq.optional_next::<&str>().unwrap(), Some("[]"));
assert_eq!(seq.optional_next::<Vec<u8>>().unwrap(), Some(vec![]));
assert_eq!(seq.optional_next::<serde_json::Value>().unwrap(), Some(serde_json::json!({})));
}
// Nested arrays and objects are each consumed as a single element.
#[test]
fn params_sequence_optional_nesting_works() {
let nested = Params::new(Some(r#"[1, [2], [3, 4], [[5], [6,7], []], {"named":7}]"#));
let mut seq = nested.sequence();
assert_eq!(seq.optional_next::<i8>().unwrap(), Some(1));
assert_eq!(seq.optional_next::<[i8; 1]>().unwrap(), Some([2]));
assert_eq!(seq.optional_next::<Vec<u16>>().unwrap(), Some(vec![3, 4]));
assert_eq!(seq.optional_next::<Vec<Vec<u32>>>().unwrap(), Some(vec![vec![5], vec![6, 7], vec![]]));
assert_eq!(seq.optional_next::<serde_json::Value>().unwrap(), Some(serde_json::json!({"named":7})));
}
}