openssh_sftp_client/sftp/openssh_session.rs

1use std::{fmt::Debug, future::Future, ops::Deref, pin::Pin, sync::Arc};
2
3use openssh::{ChildStdin, ChildStdout, Error as OpensshError, Session, Stdio};
4use tokio::{sync::oneshot, task::JoinHandle};
5
6use crate::{utils::ErrorExt, Error, Sftp, SftpAuxiliaryData, SftpOptions};
7
8/// The openssh session
9#[derive(Debug)]
10pub struct OpensshSession(JoinHandle<Option<Error>>);
11
12/// Check for openssh connection to be alive
13pub trait CheckOpensshConnection {
14    /// This function should only return on `Err()`.
15    /// Once the sftp session is closed, the future will be cancelled (dropped).
16    fn check_connection<'session>(
17        self: Box<Self>,
18        session: &'session Session,
19    ) -> Pin<Box<dyn Future<Output = Result<(), OpensshError>> + Send + Sync + 'session>>;
20}
21
22impl<F> CheckOpensshConnection for F
23where
24    F: for<'session> FnOnce(
25        &'session Session,
26    ) -> Pin<
27        Box<dyn Future<Output = Result<(), OpensshError>> + Send + Sync + 'session>,
28    >,
29{
30    fn check_connection<'session>(
31        self: Box<Self>,
32        session: &'session Session,
33    ) -> Pin<Box<dyn Future<Output = Result<(), OpensshError>> + Send + Sync + 'session>> {
34        (self)(session)
35    }
36}
37
38impl Drop for OpensshSession {
39    fn drop(&mut self) {
40        self.0.abort();
41    }
42}
43
44#[cfg_attr(
45    feature = "tracing",
46    tracing::instrument(name = "session_task", skip(tx, check_openssh_connection))
47)]
48async fn create_session_task(
49    session: impl Deref<Target = Session> + Clone + Debug + Send + Sync,
50    tx: oneshot::Sender<Result<(ChildStdin, ChildStdout), OpensshError>>,
51    check_openssh_connection: Option<Box<dyn CheckOpensshConnection + Send + Sync>>,
52) -> Option<Error> {
53    #[cfg(feature = "tracing")]
54    tracing::info!("Connecting to sftp subsystem, session = {session:?}");
55
56    let res = Session::to_subsystem(session.clone(), "sftp")
57        .stdin(Stdio::piped())
58        .stdout(Stdio::piped())
59        .stderr(Stdio::null())
60        .spawn()
61        .await;
62
63    let mut child = match res {
64        Ok(child) => child,
65        Err(err) => {
66            #[cfg(feature = "tracing")]
67            tracing::error!(
68                "Failed to connect to remote sftp subsystem: {err}, session = {session:?}"
69            );
70
71            tx.send(Err(err)).unwrap(); // Err
72            return None;
73        }
74    };
75
76    #[cfg(feature = "tracing")]
77    tracing::info!("Connection to sftp subsystem established, session = {session:?}");
78
79    let stdin = child.stdin().take().unwrap();
80    let stdout = child.stdout().take().unwrap();
81    tx.send(Ok((stdin, stdout))).unwrap(); // Ok
82
83    let original_error = {
84        let check_conn_future = async {
85            if let Some(checker) = check_openssh_connection {
86                checker
87                    .check_connection(&session)
88                    .await
89                    .err()
90                    .map(Error::from)
91            } else {
92                None
93            }
94        };
95
96        let wait_on_child_future = async {
97            match child.wait().await {
98                Ok(exit_status) => {
99                    if !exit_status.success() {
100                        Some(Error::SftpServerFailure(exit_status))
101                    } else {
102                        None
103                    }
104                }
105                Err(err) => Some(err.into()),
106            }
107        };
108        tokio::pin!(wait_on_child_future);
109
110        tokio::select! {
111            biased;
112
113            original_error = check_conn_future => {
114                let occuring_error = wait_on_child_future.await;
115                match (original_error, occuring_error) {
116                    (Some(original_error), Some(occuring_error)) => {
117                        Some(original_error.error_on_cleanup(occuring_error))
118                    }
119                    (Some(err), None) | (None, Some(err)) => Some(err),
120                    (None, None) => None,
121                }
122            }
123            original_error = &mut wait_on_child_future => original_error,
124        }
125    };
126
127    #[cfg(feature = "tracing")]
128    if let Some(err) = &original_error {
129        tracing::error!(
130            "Waiting on remote sftp subsystem to exit failed: {err}, session = {session:?}"
131        );
132    }
133
134    original_error
135}
136
137impl Sftp {
138    /// Create [`Sftp`] from [`openssh::Session`].
139    ///
140    /// Calling [`Sftp::close`] on sftp instances created using this function
141    /// would also await on [`openssh::RemoteChild::wait`] and
142    /// [`openssh::Session::close`] and propagate their error in
143    /// [`Sftp::close`].
144    pub async fn from_session(session: Session, options: SftpOptions) -> Result<Self, Error> {
145        Self::from_session_with_check_connection_inner(session, options, None).await
146    }
147
148    /// Similar to [`Sftp::from_session`], but takes an additional parameter
149    /// for checking if the connection is still alive.
150    ///
151    /// # Example
152    ///
153    /// ```rust,no_run
154    /// fn check_connection<'session>(
155    ///     session: &'session openssh::Session,
156    /// ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), openssh::Error>> + Send + Sync + 'session>> {
157    ///     Box::pin(async move {
158    ///         loop {
159    ///             tokio::time::sleep(std::time::Duration::from_secs(10)).await;
160    ///             session.check().await?;
161    ///         }
162    ///         Ok(())
163    ///     })
164    /// }
165    ///
166    /// # #[tokio::main(flavor = "current_thread")]
167    /// # async fn main() -> Result<(), openssh_sftp_client::Error> {
168    /// openssh_sftp_client::Sftp::from_session_with_check_connection(
169    ///     openssh::Session::connect_mux("me@ssh.example.com", openssh::KnownHosts::Strict).await?,
170    ///     openssh_sftp_client::SftpOptions::default(),
171    ///     check_connection,
172    /// ).await?;
173    /// # Ok(())
174    /// # }
175    /// ```
176    pub async fn from_session_with_check_connection(
177        session: Session,
178        options: SftpOptions,
179        check_openssh_connection: impl CheckOpensshConnection + Send + Sync + 'static,
180    ) -> Result<Self, Error> {
181        Self::from_session_with_check_connection_inner(
182            session,
183            options,
184            Some(Box::new(check_openssh_connection)),
185        )
186        .await
187    }
188
189    async fn from_session_with_check_connection_inner(
190        session: Session,
191        options: SftpOptions,
192        check_openssh_connection: Option<Box<dyn CheckOpensshConnection + Send + Sync>>,
193    ) -> Result<Self, Error> {
194        let (tx, rx) = oneshot::channel();
195
196        Self::from_session_task(
197            options,
198            rx,
199            tokio::spawn(async move {
200                let original_error =
201                    create_session_task(&session, tx, check_openssh_connection).await;
202
203                let _session_str = format!("{session:?}");
204                let occuring_error = session.close().await.err().map(Error::from);
205
206                #[cfg(feature = "tracing")]
207                if let Some(err) = &occuring_error {
208                    tracing::error!("Closing session failed: {err}, session = {_session_str}");
209                }
210
211                match (original_error, occuring_error) {
212                    (Some(original_error), Some(occuring_error)) => {
213                        Some(original_error.error_on_cleanup(occuring_error))
214                    }
215                    (Some(err), None) | (None, Some(err)) => Some(err),
216                    (None, None) => None,
217                }
218            }),
219        )
220        .await
221    }
222
223    /// Create [`Sftp`] from any type that can be dereferenced to [`openssh::Session`]
224    /// and is clonable.
225    pub async fn from_clonable_session(
226        session: impl Deref<Target = Session> + Clone + Debug + Send + Sync + 'static,
227        options: SftpOptions,
228    ) -> Result<Self, Error> {
229        Self::from_clonable_session_with_check_connection_inner(session, options, None).await
230    }
231
232    /// Similar to [`Sftp::from_session_with_check_connection`], but takes an additional parameter
233    /// for checking if the connection is still alive.
234    ///
235    /// # Example
236    ///
237    /// ```rust,no_run
238    /// fn check_connection<'session>(
239    ///     session: &'session openssh::Session,
240    /// ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), openssh::Error>> + Send + Sync + 'session>> {
241    ///     Box::pin(async move {
242    ///         loop {
243    ///             tokio::time::sleep(std::time::Duration::from_secs(10)).await;
244    ///             session.check().await?;
245    ///         }
246    ///         Ok(())
247    ///     })
248    /// }
249    ///
250    /// # #[tokio::main(flavor = "current_thread")]
251    /// # async fn main() -> Result<(), openssh_sftp_client::Error> {
252    /// openssh_sftp_client::Sftp::from_clonable_session_with_check_connection(
253    ///     std::sync::Arc::new(openssh::Session::connect_mux("me@ssh.example.com", openssh::KnownHosts::Strict).await?),
254    ///     openssh_sftp_client::SftpOptions::default(),
255    ///     check_connection,
256    /// ).await?;
257    /// # Ok(())
258    /// # }
259    /// ```
260    pub async fn from_clonable_session_with_check_connection(
261        session: impl Deref<Target = Session> + Clone + Debug + Send + Sync + 'static,
262        options: SftpOptions,
263        check_openssh_connection: impl CheckOpensshConnection + Send + Sync + 'static,
264    ) -> Result<Self, Error> {
265        Self::from_clonable_session_with_check_connection_inner(
266            session,
267            options,
268            Some(Box::new(check_openssh_connection)),
269        )
270        .await
271    }
272
273    async fn from_clonable_session_with_check_connection_inner(
274        session: impl Deref<Target = Session> + Clone + Debug + Send + Sync + 'static,
275        options: SftpOptions,
276        check_openssh_connection: Option<Box<dyn CheckOpensshConnection + Send + Sync>>,
277    ) -> Result<Self, Error> {
278        let (tx, rx) = oneshot::channel();
279
280        Self::from_session_task(
281            options,
282            rx,
283            tokio::spawn(create_session_task(session, tx, check_openssh_connection)),
284        )
285        .await
286    }
287
288    async fn from_session_task(
289        options: SftpOptions,
290        rx: oneshot::Receiver<Result<(ChildStdin, ChildStdout), OpensshError>>,
291        handle: JoinHandle<Option<Error>>,
292    ) -> Result<Self, Error> {
293        let msg = "Task failed without sending anything, so it must have panicked";
294
295        let (stdin, stdout) = match rx.await {
296            Ok(res) => res?,
297            Err(_) => return Err(handle.await.expect_err(msg).into()),
298        };
299
300        Self::new_with_auxiliary(
301            stdin,
302            stdout,
303            options,
304            SftpAuxiliaryData::ArcedOpensshSession(Arc::new(OpensshSession(handle))),
305        )
306        .await
307    }
308}
309
310impl OpensshSession {
311    pub(super) async fn recover_session_err(mut self) -> Result<(), Error> {
312        if let Some(err) = (&mut self.0).await? {
313            Err(err)
314        } else {
315            Ok(())
316        }
317    }
318}