sqlite_vfs_http/
vfs.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
use super::*;
use rand::{thread_rng, Rng};
use sqlite_vfs::{OpenKind, OpenOptions, Vfs};
use std::{
    io::{Error, ErrorKind},
    time::Duration,
};

/// Name of the HTTP VFS. Pass it as the VFS name when opening a connection,
/// e.g. `rusqlite::Connection::open_with_flags_and_vfs(url, flags, HTTP_VFS)`.
pub const HTTP_VFS: &str = "http";

/// SQLite VFS that opens database files addressed by HTTP URL.
///
/// Writes are not supported: only the main database file can be opened, and
/// deletion is rejected. Every opened file is backed by a `Connection` built
/// from the two settings below.
pub struct HttpVfs {
    /// Forwarded to `Connection::new` as its block size.
    /// NOTE(review): presumably the byte size of each HTTP range fetch — confirm.
    pub(crate) block_size: usize,
    /// Forwarded to `Connection::new` as its download threshold.
    /// NOTE(review): presumably the file size below which the whole database
    /// is fetched at once — confirm against `Connection::new`.
    pub(crate) download_threshold: usize,
}

impl Vfs for HttpVfs {
    type Handle = Connection;

    /// Opens the remote database named by `db` (a URL) via `Connection::new`.
    ///
    /// Only `OpenKind::MainDb` is accepted. Any other file kind SQLite asks for
    /// (e.g. journal or WAL files) is refused with
    /// `ErrorKind::ReadOnlyFilesystem`.
    fn open(&self, db: &str, opts: OpenOptions) -> Result<Self::Handle, Error> {
        if opts.kind == OpenKind::MainDb {
            let handle = Connection::new(db, self.block_size, self.download_threshold)?;
            Ok(handle)
        } else {
            Err(Error::new(
                ErrorKind::ReadOnlyFilesystem,
                "only main database supported",
            ))
        }
    }

    /// Always fails: files cannot be removed through this VFS.
    fn delete(&self, _db: &str) -> Result<(), Error> {
        let reason = "delete operation is not supported";
        Err(Error::new(ErrorKind::ReadOnlyFilesystem, reason))
    }

    /// Reports every queried path as absent.
    ///
    /// NOTE(review): SQLite therefore never sees a journal/WAL file alongside
    /// the remote database — presumably intended for a write-less VFS; confirm.
    fn exists(&self, _db: &str) -> Result<bool, Error> {
        Ok(false)
    }

    /// Returns the fixed name `main.db` for temporary files.
    fn temporary_name(&self) -> String {
        "main.db".to_owned()
    }

    /// Fills `buffer` with random bytes from `rand`'s thread-local generator.
    fn random(&self, buffer: &mut [i8]) {
        let mut rng = thread_rng();
        Rng::fill(&mut rng, buffer);
    }

    /// Blocks the calling thread for `duration`, then reports that the full
    /// requested duration elapsed.
    fn sleep(&self, duration: Duration) -> Duration {
        let requested = duration;
        std::thread::sleep(requested);
        requested
    }
}

#[cfg(test)]
mod tests {
    use std::future::Future;

    use super::*;
    use rusqlite::{Connection, OpenFlags};
    use tokio::time::sleep;

    /// Counts the tables defined in the opened database.
    const QUERY_SQLITE_MASTER: &str = "SELECT count(1) FROM sqlite_master WHERE type = 'table'";
    /// Reads every `name` from the `test` table (present in fixture database `1`).
    const QUERY_TEST: &str = "SELECT name FROM test";

    /// Minimal Rocket server that serves prebuilt SQLite files over HTTP.
    mod server {
        use rocket::{custom, figment::Figment, get, routes, Config, Shutdown, State};
        use rocket_seek_stream::SeekStream;
        use rusqlite::Connection;
        use std::{collections::HashMap, fs::read, io::Cursor, thread::JoinHandle};
        use tempfile::tempdir;
        use tokio::runtime::Runtime;

        /// Builds the fixture databases on disk in a temp dir and returns their
        /// raw file bytes keyed by id:
        /// `0` has two empty tables, `1` has table `test` with rows Alice and Bob.
        fn init_database() -> HashMap<i64, Vec<u8>> {
            let schemas = [
                vec![
                    "PRAGMA journal_mode = MEMORY;",
                    "CREATE TABLE test1 (id INTEGER PRIMARY KEY, name TEXT);",
                    "CREATE TABLE test2 (id INTEGER PRIMARY KEY, name TEXT);",
                ],
                vec![
                    "PRAGMA journal_mode = MEMORY;",
                    "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT);",
                    "INSERT INTO test (name) VALUES ('Alice');",
                    "INSERT INTO test (name) VALUES ('Bob');",
                ],
            ];
            let mut database = HashMap::new();

            let temp = tempdir().unwrap();

            for (i, schema) in schemas.into_iter().enumerate() {
                let path = temp.path().join(format!("{i}.db"));
                let conn = Connection::open(&path).unwrap();
                conn.execute_batch(&schema.join("\n")).unwrap();
                conn.close().unwrap();
                database.insert(i as i64, read(&path).unwrap());
            }

            database
        }

        /// Serves fixture database `id` as a seekable stream, or 404 when unknown.
        #[get("/<id>")]
        pub async fn database(
            db: &State<HashMap<i64, Vec<u8>>>,
            id: i64,
        ) -> Option<SeekStream<'static>> {
            db.get(&id).map(|buffer| {
                // The stream must own its data ('static), so hand it a copy.
                let len = buffer.len() as u64;
                SeekStream::with_opts(Cursor::new(buffer.clone()), len, None)
            })
        }

        /// Asks Rocket to shut down gracefully.
        #[get("/shutdown")]
        pub async fn shutdown(shutdown: Shutdown) -> &'static str {
            shutdown.notify();
            "Shutting down..."
        }

        /// Starts the server on port 4096 in a dedicated thread with its own
        /// Tokio runtime. The handle yields Rocket's launch/run result.
        pub fn launch() -> JoinHandle<Result<(), rocket::Error>> {
            std::thread::spawn(|| {
                let rt = Runtime::new().unwrap();
                rt.block_on(async {
                    custom(Figment::from(Config::default()).merge(("port", 4096)))
                        .manage(init_database())
                        .mount("/", routes![database, shutdown])
                        .launch()
                        .await?;

                    Ok(())
                })
            })
        }
    }

    /// Launches the fixture server and waits until it answers. It then runs
    /// `future` with the server's base URL and shuts the server down again.
    ///
    /// # Errors
    /// Returns the server's launch error if it stops before becoming ready.
    /// Otherwise returns the body's error, or a shutdown/run error from the server.
    async fn init_server<C, F>(future: C) -> anyhow::Result<()>
    where
        C: FnOnce(String) -> F,
        F: Future<Output = anyhow::Result<()>>,
    {
        let base = "http://127.0.0.1:4096";
        let server = server::launch();

        // Wait for the server to start. No route is mounted at `/`, so a 404
        // there means Rocket is listening.
        loop {
            if let Ok(resp) = reqwest::get(base).await {
                if resp.status() == 404 {
                    break;
                }
            }
            if server.is_finished() {
                // The server thread exited without ever answering (e.g. port
                // 4096 already in use). Surface its error instead of polling forever.
                server.join().unwrap()?;
                anyhow::bail!("test server stopped before it became ready");
            }
            sleep(Duration::from_millis(100)).await;
        }

        // Capture the body's outcome so the server is shut down even when the
        // body fails; otherwise the thread would keep the port bound.
        let outcome = future(base.into()).await;

        reqwest::get(format!("{base}/shutdown").as_str()).await?;
        server.join().unwrap()?;

        outcome
    }

    /// Opens `url` through the HTTP VFS with the flags used by these tests.
    fn open_http_db(url: String) -> rusqlite::Result<Connection> {
        Connection::open_with_flags_and_vfs(
            url,
            OpenFlags::SQLITE_OPEN_READ_WRITE
                | OpenFlags::SQLITE_OPEN_CREATE
                | OpenFlags::SQLITE_OPEN_NO_MUTEX,
            HTTP_VFS,
        )
    }

    #[tokio::test]
    async fn test_http_vfs() {
        init_server(|base| async move {
            vfs::register_http_vfs();

            {
                let conn = open_http_db(format!("{base}/0"))?;
                assert_eq!(
                    conn.query_row::<usize, _, _>(QUERY_SQLITE_MASTER, [], |row| row.get(0))?,
                    2
                );
            }

            {
                let conn = open_http_db(format!("{base}/1"))?;
                let mut stmt = conn.prepare(QUERY_TEST)?;
                assert_eq!(
                    stmt.query_map([], |row| row.get::<_, String>(0))?
                        .collect::<Result<Vec<_>, _>>()?,
                    vec!["Alice".to_string(), "Bob".to_string()]
                );
            }

            Ok(())
        })
        .await
        .unwrap();
    }
}