1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
use async_trait::async_trait;
use qdrant_client::prelude::*;
use serde::{Deserialize, Serialize};
use shuttle_service::{
    error::{CustomError, Error},
    resource::{ProvisionResourceRequest, Type},
    ContainerRequest, ContainerResponse, Environment, IntoResource, ResourceFactory,
    ResourceInputBuilder, ShuttleResourceOutput,
};

/// A Qdrant vector database resource.
///
/// Configure it with the builder methods. Where the connection comes from depends on the
/// environment and on which URLs have been set.
#[derive(Serialize, Default)]
pub struct Qdrant {
    /// URL of the hosted Qdrant instance; must be set when deploying.
    cloud_url: Option<String>,
    /// API key, needed only when the endpoint is protected by a key.
    api_key: Option<String>,
    /// When set, this URL is used for local runs instead of the default docker container.
    local_url: Option<String>,
}

impl Qdrant {
    /// Sets the URL of the Qdrant instance used when the service is deployed.
    ///
    /// Deployments require this value: building the resource without it returns a
    /// "missing `cloud_url` parameter" error.
    #[must_use = "builder methods return the updated value; the original is consumed"]
    pub fn cloud_url(mut self, cloud_url: &str) -> Self {
        self.cloud_url = Some(cloud_url.to_string());
        self
    }

    /// Sets the API key passed to the Qdrant client.
    ///
    /// It is used with `cloud_url` in deployments and with `local_url` in local runs.
    /// It is not forwarded when a local docker container is provisioned instead.
    #[must_use = "builder methods return the updated value; the original is consumed"]
    pub fn api_key(mut self, api_key: &str) -> Self {
        self.api_key = Some(api_key.to_string());
        self
    }

    /// Sets a URL to use on local runs instead of provisioning the default docker container.
    #[must_use = "builder methods return the updated value; the original is consumed"]
    pub fn local_url(mut self, local_url: &str) -> Self {
        self.local_url = Some(local_url.to_string());
        self
    }
}

/// Input produced by the `Qdrant` builder.
///
/// It either asks Shuttle to provision a resource (a container), or it carries a
/// ready-made client configuration so that nothing needs provisioning.
/// The enum is untagged, so variants are tried in declaration order during deserialization.
#[derive(Deserialize, Serialize)]
#[serde(untagged)]
pub enum MaybeRequest {
    /// A resource provisioning request forwarded to Shuttle.
    Request(ProvisionResourceRequest),
    /// A connection config that is used directly, with no provisioning.
    NotRequest(QdrantClientConfigWrap),
}

#[async_trait]
impl ResourceInputBuilder for Qdrant {
    type Input = MaybeRequest;
    // The response can be a provisioned container, depending on local/deployment and config.
    type Output = OutputWrapper;

    async fn build(self, factory: &ResourceFactory) -> Result<Self::Input, Error> {
        let md = factory.get_metadata();
        match md.env {
            Environment::Deployment => match self.cloud_url {
                Some(cloud_url) => Ok(MaybeRequest::NotRequest(QdrantClientConfigWrap {
                    url: cloud_url,
                    api_key: self.api_key,
                })),
                None => Err(Error::Custom(CustomError::msg(
                    "missing `cloud_url` parameter",
                ))),
            },
            Environment::Local => match self.local_url {
                Some(local_url) => Ok(MaybeRequest::NotRequest(QdrantClientConfigWrap {
                    url: local_url,
                    api_key: self.api_key,
                })),
                None => Ok(MaybeRequest::Request(ProvisionResourceRequest::new(
                    Type::Container,
                    serde_json::to_value(ContainerRequest {
                        project_name: md.project_name,
                        container_name: "qdrant".to_string(),
                        image: "docker.io/qdrant/qdrant:v1.7.4".to_string(),
                        port: "6334/tcp".to_string(),
                        env: vec![],
                    })
                    .unwrap(),
                    serde_json::Value::Null,
                ))),
            },
        }
    }
}

/// Output handed back to `into_resource`.
///
/// It holds either the response for a provisioned container or a connection config
/// that was supplied directly. The enum is untagged, so the container shape is tried first.
#[derive(Deserialize, Serialize)]
#[serde(untagged)]
pub enum OutputWrapper {
    /// A container provisioned by Shuttle.
    Container(ShuttleResourceOutput<ContainerResponse>),
    /// A user-supplied URL and optional API key.
    Config(QdrantClientConfigWrap),
}

/// Minimal stand-in for `QdrantClientConfig` that is `Clone` and serde-serializable.
///
/// It lets the connection details travel through the resource builder.
#[derive(Deserialize, Serialize, Clone)]
pub struct QdrantClientConfigWrap {
    /// Qdrant endpoint URL.
    url: String,
    /// Optional API key for protected endpoints.
    api_key: Option<String>,
}

#[async_trait]
impl IntoResource<QdrantClient> for OutputWrapper {
    /// Builds a `QdrantClient` from the resolved connection details.
    ///
    /// A provisioned container is reached at `http://localhost:<host_port>`, with no API key.
    /// A supplied config is used as-is.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `QdrantClientConfig::build`.
    async fn into_resource(self) -> Result<QdrantClient, Error> {
        let QdrantClientConfigWrap { url, api_key } = match self {
            Self::Config(wrap) => wrap,
            Self::Container(provisioned) => {
                let host_port = provisioned.output.host_port;
                QdrantClientConfigWrap {
                    url: format!("http://localhost:{}", host_port),
                    api_key: None,
                }
            }
        };
        let client = QdrantClientConfig::from_url(&url)
            .with_api_key(api_key)
            .build()?;
        Ok(client)
    }
}