1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
use std::marker::PhantomData;
use std::ops::DerefMut;

use rocket::fairing::{self, Fairing, Info, Kind};
use rocket::http::Status;
use rocket::request::{FromRequest, Outcome, Request};
use rocket::{error, info_, Build, Ignite, Phase, Rocket, Sentinel};

use rocket::figment::providers::Serialized;
use rocket::yansi::Paint;

#[cfg(feature = "rocket_okapi")]
use rocket_okapi::{
    gen::OpenApiGenerator,
    request::{OpenApiFromRequest, RequestHeaderInput},
};

use crate::Pool;

/// Ties a database [`Pool`] type to the name of its configuration section.
///
/// Writing an implementation of this trait by hand should almost never be
/// necessary; derive it instead:
///
/// ```ignore
/// use sea_orm_rocket::{Database};
/// # use sea_orm_rocket::MockPool as SeaOrmPool;
///
/// #[derive(Database, Debug)]
/// #[database("sea_orm")]
/// struct Db(SeaOrmPool);
///
/// #[launch]
/// fn rocket() -> _ {
///     rocket::build().attach(Db::init())
/// }
/// ```
///
/// The [`Database` derive](derive@crate::Database) documents the generated
/// code in detail.
pub trait Database:
    From<Self::Pool> + DerefMut<Target = Self::Pool> + Send + Sync + 'static
{
    /// The connection-pool type that backs this database.
    ///
    /// For a derived `struct Db(Inner)`, this is `Inner`.
    type Pool: Pool;

    /// The name of this database's configuration entry.
    ///
    /// For a derive annotated with `#[database("name")]`, this is `"name"`.
    const NAME: &'static str;

    /// Creates the fairing that sets up this database and its connection pool.
    ///
    /// # Example
    ///
    /// ```rust
    /// # mod _inner {
    /// # use rocket::launch;
    /// use sea_orm_rocket::Database;
    /// # use sea_orm_rocket::MockPool as SeaOrmPool;
    ///
    /// #[derive(Database)]
    /// #[database("sea_orm")]
    /// struct Db(SeaOrmPool);
    ///
    /// #[launch]
    /// fn rocket() -> _ {
    ///     rocket::build().attach(Db::init())
    /// }
    /// # }
    /// ```
    fn init() -> Initializer<Self> {
        Initializer::new()
    }

    /// Looks up the initialized database in `rocket`'s managed state.
    ///
    /// A `Some` result requires that the fairing returned by `init()` has
    /// already run. That holds whenever the fairing is attached and either:
    ///
    ///   * Rocket is in the [`Orbit`](rocket::Orbit) phase, meaning the
    ///     application is running. Request guards and liftoff fairings always
    ///     observe this phase,
    ///   * _or_ Rocket is in the [`Build`](rocket::Build) or
    ///     [`Ignite`](rocket::Ignite) phase and the `Initializer` fairing has
    ///     already executed. This is true for every fairing callback belonging
    ///     to a fairing attached _after_ the `Initializer` fairing.
    ///
    /// On a miss, an error naming the database type and a hint to attach
    /// `<type>::init()` are logged before `None` is returned.
    ///
    /// # Example
    ///
    /// Run database migrations in an ignite fairing. The migration fairing
    /// must be attached _after_ the `init()` fairing.
    ///
    /// ```rust
    /// # mod _inner {
    /// # use rocket::launch;
    /// use rocket::fairing::{self, AdHoc};
    /// use rocket::{Build, Rocket};
    ///
    /// use sea_orm_rocket::Database;
    /// # use sea_orm_rocket::MockPool as SeaOrmPool;
    ///
    /// #[derive(Database)]
    /// #[database("sea_orm")]
    /// struct Db(SeaOrmPool);
    ///
    /// async fn run_migrations(rocket: Rocket<Build>) -> fairing::Result {
    ///     if let Some(db) = Db::fetch(&rocket) {
    ///         // run migrations using `db`. get the inner type with &db.0.
    ///         Ok(rocket)
    ///     } else {
    ///         Err(rocket)
    ///     }
    /// }
    ///
    /// #[launch]
    /// fn rocket() -> _ {
    ///     rocket::build()
    ///         .attach(Db::init())
    ///         .attach(AdHoc::try_on_ignite("DB Migrations", run_migrations))
    /// }
    /// # }
    /// ```
    fn fetch<P: Phase>(rocket: &Rocket<P>) -> Option<&Self> {
        match rocket.state::<Self>() {
            Some(db) => Some(db),
            None => {
                let type_name = std::any::type_name::<Self>();
                error!(
                    "Attempted to fetch unattached database `{}`.",
                    Paint::default(type_name).bold()
                );

                let init_call = Paint::default(format!("{}::init()", type_name)).bold();
                info_!(
                    "`{}` fairing must be attached prior to using this database.",
                    init_call
                );

                None
            }
        }
    }
}

/// A [`Fairing`] which initializes a [`Database`] and its connection pool.
///
/// A value of this type can be created for any type `D` that implements
/// [`Database`] via the [`Database::init()`] method on the type. Normally, a
/// value of this type _never_ needs to be constructed directly. This
/// documentation exists purely as a reference.
///
/// This fairing runs at ignite time and initializes a database pool.
/// Specifically, it:
///
///   1. Reads the configuration at `databases.db_name`, where `db_name` is
///      [`Database::NAME`].
///
///   2. Sets [`Config`](crate::Config) defaults on the configuration figment:
///      `max_connections` (4 × the configured Rocket worker count),
///      `connect_timeout` (5) and `sqlx_logging` (`true`).
///
///   3. Calls [`Pool::init()`]. If it fails, the error is logged and ignition
///      fails.
///
///   4. Stores the database instance in managed storage, retrievable via
///      [`Database::fetch()`].
///
/// The name of the fairing itself is `Initializer<D>`, with `D` replaced with
/// the type name `D` unless a name is explicitly provided via
/// [`Self::with_name()`].
pub struct Initializer<D: Database>(Option<&'static str>, PhantomData<fn() -> D>);

/// A request guard which retrieves a single connection to a [`Database`].
///
/// For a database type of `Db`, a request guard of `Connection<Db>` borrows
/// the connection held by `Db`'s [`Pool`] for the lifetime of the request.
///
/// The request guard succeeds whenever the database was initialized by the
/// [`Initializer`] fairing.
///
///   * If the `Initializer` fairing was _not_ attached, the guard _fails_ with
///     status `InternalServerError` and a `None` error. A [`Sentinel`] guards
///     this condition, so this type of failure is unlikely to occur.
///   * The `Some` variant of the error type is reserved for pool errors; the
///     current guard implementation never produces it.
// NOTE(review): earlier docs promised a `ServiceUnavailable` failure when no
// connection was available within `connect_timeout` seconds; the `FromRequest`
// impl in this module only borrows the pool's connection and never does that.
pub struct Connection<'a, D: Database>(&'a <D::Pool as Pool>::Connection);

impl<D: Database> Initializer<D> {
    /// Returns a database initializer fairing for `D`.
    ///
    /// This method should never need to be called manually. See the [crate
    /// docs](crate) for usage information.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        Self(None, std::marker::PhantomData)
    }

    /// Returns a database initializer fairing for `D` with name `name`.
    ///
    /// This method should never need to be called manually. See the [crate
    /// docs](crate) for usage information.
    pub fn with_name(name: &'static str) -> Self {
        Self(Some(name), std::marker::PhantomData)
    }
}

impl<'a, D: Database> Connection<'a, D> {
    /// Returns the internal connection value. See the [`Connection` Deref
    /// column](crate#supported-drivers) for the expected type of this value.
    pub fn into_inner(self) -> &'a <D::Pool as Pool>::Connection {
        self.0
    }
}

#[cfg(feature = "rocket_okapi")]
impl<'r, D: Database> OpenApiFromRequest<'r> for Connection<'r, D> {
    /// The connection is supplied by the server, not the client, so this
    /// guard contributes no request input (`RequestHeaderInput::None`) to the
    /// generated OpenAPI document.
    fn from_request_input(
        _: &mut OpenApiGenerator,
        _: String,
        _: bool,
    ) -> rocket_okapi::Result<RequestHeaderInput> {
        let no_input = RequestHeaderInput::None;
        Ok(no_input)
    }
}

#[rocket::async_trait]
impl<D: Database> Fairing for Initializer<D> {
    /// Names the fairing (the explicit name from `with_name`, otherwise the
    /// type name of `Initializer<D>`) and registers it as an ignite fairing.
    fn info(&self) -> Info {
        Info {
            name: self.0.unwrap_or_else(std::any::type_name::<Self>),
            kind: Kind::Ignite,
        }
    }

    /// Builds `D`'s pool from the `databases.<D::NAME>` configuration and
    /// stores `D` in managed state.
    ///
    /// `max_connections` (4 per worker), `connect_timeout` (5) and
    /// `sqlx_logging` (`true`) are supplied as fallbacks. They are used only
    /// when the configuration does not already set those keys.
    ///
    /// # Errors
    ///
    /// If `Pool::init` fails, the error is logged and `Err(rocket)` is
    /// returned, which fails ignition.
    async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
        let workers: usize = rocket
            .figment()
            .extract_inner(rocket::Config::WORKERS)
            .unwrap_or_else(|_| rocket::Config::default().workers);

        // `join`, not `merge`: on conflict `join` keeps the existing value, so
        // user settings (e.g. `sqlx_logging = false` in the `[default]`
        // profile) win over these fallbacks. `merge` would overwrite them.
        let figment = rocket
            .figment()
            .focus(&format!("databases.{}", D::NAME))
            .join(Serialized::default("max_connections", workers * 4))
            .join(Serialized::default("connect_timeout", 5))
            .join(Serialized::default("sqlx_logging", true));

        match <D::Pool>::init(&figment).await {
            Ok(pool) => Ok(rocket.manage(D::from(pool))),
            Err(e) => {
                error!("failed to initialize database: {}", e);
                Err(rocket)
            }
        }
    }
}

#[rocket::async_trait]
impl<'r, D: Database> FromRequest<'r> for Connection<'r, D> {
    /// `None` signals that `D` was never initialized. The `Some` variant is
    /// not produced by this implementation.
    type Error = Option<<D::Pool as Pool>::Error>;

    /// Borrows `D`'s connection from managed state. If `D` was never
    /// initialized, the guard fails with `InternalServerError` and no error
    /// value (`D::fetch` logs the cause).
    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        if let Some(db) = D::fetch(req.rocket()) {
            let conn = db.borrow();
            return Outcome::Success(Connection(conn));
        }

        Outcome::Failure((Status::InternalServerError, None))
    }
}

impl<D: Database> Sentinel for Connection<'_, D> {
    /// Requests a launch abort when `D` is not in managed state, i.e. its
    /// `Initializer` fairing was not attached. `D::fetch` logs the reason.
    fn abort(rocket: &Rocket<Ignite>) -> bool {
        let attached = D::fetch(rocket).is_some();
        !attached
    }
}