1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
use async_graphql::{
    dynamic::{Field, FieldFuture, FieldValue, InputValue, Object, ResolverContext, TypeRef},
    Context, FieldResult,
};
use dynamic_graphql::{
    internal::{OutputTypeName, Register, Registry, ResolveOwned, TypeName},
    SimpleObject,
};
use once_cell::sync::Lazy;
use ordered_float::OrderedFloat;
use raphtory::{
    algorithms::pagerank::unweighted_page_rank,
    db::api::view::{internal::DynamicGraph, GraphViewOps},
};
use std::{borrow::Cow, collections::HashMap, sync::Mutex};

/// Signature of a function that adds one algorithm field to the `Algorithms` object.
///
/// It receives the field name, the registry and the object under construction,
/// and hands both back (possibly extended).
type RegisterFunction = fn(name: &str, registry: Registry, parent: Object) -> (Registry, Object);

/// Additional algorithm registration functions, keyed by the name passed to each one.
///
/// Every entry is invoked when the `Algorithms` type is registered.
pub(crate) static PLUGIN_ALGOS: Lazy<Mutex<HashMap<String, RegisterFunction>>> =
    Lazy::new(Mutex::default);

/// Backing value of the GraphQL `Algorithms` object: holds the graph that every
/// algorithm field is run against.
pub(crate) struct Algorithms {
    graph: DynamicGraph,
}

impl From<DynamicGraph> for Algorithms {
    /// Wraps `g` so algorithm resolvers can reach it.
    fn from(g: DynamicGraph) -> Self {
        Algorithms { graph: g }
    }
}

impl Register for Algorithms {
    fn register(registry: Registry) -> Registry {
        let mut registry = registry;
        let mut object = Object::new("Algorithms");

        let algos = HashMap::from([("pagerank", Pagerank::register_algo)]);
        for (name, register_algo) in algos {
            (registry, object) = register_algo(name, registry, object);
        }

        for (name, register_algo) in PLUGIN_ALGOS.lock().unwrap().iter() {
            (registry, object) = register_algo(name, registry, object);
        }

        registry.register_type(object)
    }
}

impl TypeName for Algorithms {
    /// GraphQL type name; matches the object created in `Register for Algorithms`.
    fn get_type_name() -> Cow<'static, str> {
        Cow::Borrowed("Algorithms")
    }
}

/// Marker impl: `Algorithms` may be used as a GraphQL output type.
impl OutputTypeName for Algorithms {}

impl<'a> ResolveOwned<'a> for Algorithms {
    /// Wraps `self` as an opaque owned value; algorithm field resolvers later
    /// downcast the parent value back to `Algorithms`.
    fn resolve_owned(self, _ctx: &Context) -> dynamic_graphql::Result<Option<FieldValue<'a>>> {
        let wrapped = FieldValue::owned_any(self);
        Ok(Some(wrapped))
    }
}

pub trait Algorithm: Register + 'static {
    fn output_type() -> TypeRef;
    fn args<'a>() -> Vec<(&'a str, TypeRef)>;
    fn apply_algo<'a, G: GraphViewOps>(
        graph: &G,
        ctx: ResolverContext,
    ) -> FieldResult<Option<FieldValue<'a>>>;
    fn register_algo(name: &str, registry: Registry, parent: Object) -> (Registry, Object) {
        let registry = registry.register::<Self>();
        let mut field = Field::new(name, Self::output_type(), |ctx| {
            FieldFuture::new(async move {
                let algos: &Algorithms = ctx.parent_value.downcast_ref().unwrap();
                Self::apply_algo(&algos.graph, ctx)
            })
        });
        for (name, type_ref) in Self::args() {
            field = field.argument(InputValue::new(name, type_ref));
        }
        let parent = parent.field(field);
        (registry, parent)
    }
}

// GraphQL output entry for PageRank: a `name` and its computed `rank`.
// Plain `//` comments are used deliberately here: `///` docs on a `SimpleObject`
// derive may be turned into schema descriptions.
#[derive(SimpleObject)]
struct Pagerank {
    name: String,
    rank: f64,
}

impl From<(String, f64)> for Pagerank {
    /// Builds an entry from an owned `(name, rank)` pair.
    fn from(pair: (String, f64)) -> Self {
        let (name, rank) = pair;
        Pagerank { name, rank }
    }
}

impl From<(&String, &OrderedFloat<f64>)> for Pagerank {
    /// Builds an entry from a borrowed `(name, rank)` pair, copying the name and
    /// unwrapping the ordered float.
    fn from((name, rank): (&String, &OrderedFloat<f64>)) -> Self {
        Pagerank {
            name: name.clone(),
            rank: rank.0,
        }
    }
}

impl Algorithm for Pagerank {
    fn output_type() -> TypeRef {
        // first _nn means that the list is never null, second _nn means no element is null
        TypeRef::named_nn_list_nn(Self::get_type_name()) //
    }
    fn args<'a>() -> Vec<(&'a str, TypeRef)> {
        vec![
            ("iterCount", TypeRef::named_nn(TypeRef::INT)), // _nn stands for not null
            ("threads", TypeRef::named(TypeRef::INT)),      // this one though might be null
            ("tol", TypeRef::named(TypeRef::FLOAT)),
        ]
    }
    fn apply_algo<'a, G: GraphViewOps>(
        graph: &G,
        ctx: ResolverContext,
    ) -> FieldResult<Option<FieldValue<'a>>> {
        let iter_count = ctx.args.try_get("iterCount")?.u64()? as usize;
        let threads = ctx.args.get("threads").map(|v| v.u64()).transpose()?;
        let threads = threads.map(|v| v as usize);
        let tol = ctx.args.get("tol").map(|v| v.f64()).transpose()?;
        let binding = unweighted_page_rank(graph, iter_count, threads, tol, true);
        let result = binding
            .into_iter()
            .map(|pair| FieldValue::owned_any(Pagerank::from(pair)));
        Ok(Some(FieldValue::list(result)))
    }
}