1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
use crate::model::DynamicGraph;
use async_graphql::dynamic::{
    Field, FieldFuture, FieldValue, InputValue, Object, ResolverContext, TypeRef,
};
use async_graphql::{Context, FieldResult};
use dynamic_graphql::internal::{OutputTypeName, Register, Registry, ResolveOwned, TypeName};
use dynamic_graphql::SimpleObject;
use once_cell::sync::Lazy;
use raphtory::algorithms::pagerank::unweighted_page_rank;
use raphtory::db::view_api::GraphViewOps;
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Mutex;

/// Registration hook for one algorithm.
///
/// Given the field name, the schema `Registry` and the `Algorithms` object under
/// construction, it adds the algorithm's field (and any types it needs) and hands
/// both back.
type RegisterFunction = fn(&str, Registry, Object) -> (Registry, Object);

/// Extra algorithms contributed at runtime, keyed by the field name they are exposed under.
///
/// The table is read when `Algorithms` is registered. Entries inserted after that point are
/// not seen by a registry that has already been built.
pub(crate) static PLUGIN_ALGOS: Lazy<Mutex<HashMap<String, RegisterFunction>>> =
    Lazy::new(|| Mutex::new(HashMap::default()));

/// Root value behind the `Algorithms` GraphQL object.
///
/// Every algorithm field resolver downcasts its parent value to this type to reach `graph`.
pub(crate) struct Algorithms {
    graph: DynamicGraph,
}

impl From<DynamicGraph> for Algorithms {
    /// Wraps `graph` so that the algorithm resolvers can run against it.
    fn from(graph: DynamicGraph) -> Algorithms {
        Algorithms { graph }
    }
}

impl Register for Algorithms {
    fn register(registry: Registry) -> Registry {
        let mut registry = registry;
        let mut object = Object::new("Algorithms");

        let algos = HashMap::from([("pagerank", Pagerank::register_algo)]);
        for (name, register_algo) in algos {
            (registry, object) = register_algo(name, registry, object);
        }

        for (name, register_algo) in PLUGIN_ALGOS.lock().unwrap().iter() {
            (registry, object) = register_algo(name, registry, object);
        }

        registry.register_type(object)
    }
}

impl TypeName for Algorithms {
    /// GraphQL type name. It must match the name given to `Object::new` in `register`.
    fn get_type_name() -> Cow<'static, str> {
        Cow::Borrowed("Algorithms")
    }
}

/// Opts `Algorithms` into `OutputTypeName` using the trait's default methods.
impl OutputTypeName for Algorithms {}

impl<'a> ResolveOwned<'a> for Algorithms {
    /// Hands `self` to the executor as an opaque owned value.
    ///
    /// The algorithm field resolvers later downcast it back to `Algorithms`.
    fn resolve_owned(self, _ctx: &Context) -> dynamic_graphql::Result<Option<FieldValue<'a>>> {
        let value = FieldValue::owned_any(self);
        Ok(Some(value))
    }
}

/// A graph algorithm exposed as a field on the `Algorithms` GraphQL object.
///
/// Implementors describe the field's result type and arguments and compute the result.
/// The provided [`Algorithm::register_algo`] wires these into the schema.
pub trait Algorithm: Register + 'static {
    /// GraphQL type of the field's result.
    fn output_type() -> TypeRef;

    /// Field arguments as `(name, type)` pairs, declared on the field in this order.
    fn args<'a>() -> Vec<(&'a str, TypeRef)>;

    /// Runs the algorithm on `graph`, reading its arguments from `ctx`.
    ///
    /// # Errors
    /// Implementations return an error for missing or invalid arguments.
    fn apply_algo<'a, G: GraphViewOps>(
        graph: &G,
        ctx: ResolverContext,
    ) -> FieldResult<Option<FieldValue<'a>>>;

    /// Registers `Self`'s types in `registry` and adds a field called `name` to `parent`.
    ///
    /// The added field resolves by running [`Algorithm::apply_algo`] on the graph held by
    /// the parent `Algorithms` value. Returns the updated registry and object.
    fn register_algo(name: &str, registry: Registry, parent: Object) -> (Registry, Object) {
        let registry = registry.register::<Self>();
        let field = Field::new(name, Self::output_type(), |ctx| {
            FieldFuture::new(async move {
                // The field is only ever attached to the `Algorithms` object, so this should
                // always succeed. If it does not, report a GraphQL error instead of panicking
                // the executor task.
                let algos: &Algorithms = ctx
                    .parent_value
                    .downcast_ref()
                    .ok_or("internal: algorithm field resolved on a non-`Algorithms` parent")?;
                Self::apply_algo(&algos.graph, ctx)
            })
        });
        let field = Self::args()
            .into_iter()
            .fold(field, |field, (arg_name, type_ref)| {
                field.argument(InputValue::new(arg_name, type_ref))
            });
        (registry, parent.field(field))
    }
}

/// One entry of the PageRank result.
///
/// `name` is presumably the vertex name reported by `unweighted_page_rank` (TODO confirm),
/// and `rank` is its score.
#[derive(SimpleObject)]
struct Pagerank {
    name: String,
    rank: f64,
}

impl From<(String, f64)> for Pagerank {
    /// Builds an entry from a `(name, rank)` pair.
    fn from(pair: (String, f64)) -> Self {
        let (name, rank) = pair;
        Pagerank { name, rank }
    }
}

impl Algorithm for Pagerank {
    fn output_type() -> TypeRef {
        // first _nn means that the list is never null, second _nn means no element is null
        TypeRef::named_nn_list_nn(Self::get_type_name()) //
    }
    fn args<'a>() -> Vec<(&'a str, TypeRef)> {
        vec![
            ("iterCount", TypeRef::named_nn(TypeRef::INT)), // _nn stands for not null
            ("threads", TypeRef::named(TypeRef::INT)),      // this one though might be null
            ("tol", TypeRef::named(TypeRef::FLOAT)),
        ]
    }
    fn apply_algo<'a, G: GraphViewOps>(
        graph: &G,
        ctx: ResolverContext,
    ) -> FieldResult<Option<FieldValue<'a>>> {
        let iter_count = ctx.args.try_get("iterCount")?.u64()? as usize;
        let threads = ctx.args.get("threads").map(|v| v.u64()).transpose()?;
        let threads = threads.map(|v| v as usize);
        let tol = ctx.args.get("tol").map(|v| v.f64()).transpose()?;
        let result = unweighted_page_rank(graph, iter_count, threads, tol, true)
            .into_iter()
            .map(|pair| FieldValue::owned_any(Pagerank::from(pair)));
        Ok(Some(FieldValue::list(result)))
    }
}