openssh_sftp_client/sftp/
openssh_session.rs1use std::{fmt::Debug, future::Future, ops::Deref, pin::Pin, sync::Arc};
2
3use openssh::{ChildStdin, ChildStdout, Error as OpensshError, Session, Stdio};
4use tokio::{sync::oneshot, task::JoinHandle};
5
6use crate::{utils::ErrorExt, Error, Sftp, SftpAuxiliaryData, SftpOptions};
7
8#[derive(Debug)]
10pub struct OpensshSession(JoinHandle<Option<Error>>);
11
12pub trait CheckOpensshConnection {
14 fn check_connection<'session>(
17 self: Box<Self>,
18 session: &'session Session,
19 ) -> Pin<Box<dyn Future<Output = Result<(), OpensshError>> + Send + Sync + 'session>>;
20}
21
22impl<F> CheckOpensshConnection for F
23where
24 F: for<'session> FnOnce(
25 &'session Session,
26 ) -> Pin<
27 Box<dyn Future<Output = Result<(), OpensshError>> + Send + Sync + 'session>,
28 >,
29{
30 fn check_connection<'session>(
31 self: Box<Self>,
32 session: &'session Session,
33 ) -> Pin<Box<dyn Future<Output = Result<(), OpensshError>> + Send + Sync + 'session>> {
34 (self)(session)
35 }
36}
37
38impl Drop for OpensshSession {
39 fn drop(&mut self) {
40 self.0.abort();
41 }
42}
43
44#[cfg_attr(
45 feature = "tracing",
46 tracing::instrument(name = "session_task", skip(tx, check_openssh_connection))
47)]
48async fn create_session_task(
49 session: impl Deref<Target = Session> + Clone + Debug + Send + Sync,
50 tx: oneshot::Sender<Result<(ChildStdin, ChildStdout), OpensshError>>,
51 check_openssh_connection: Option<Box<dyn CheckOpensshConnection + Send + Sync>>,
52) -> Option<Error> {
53 #[cfg(feature = "tracing")]
54 tracing::info!("Connecting to sftp subsystem, session = {session:?}");
55
56 let res = Session::to_subsystem(session.clone(), "sftp")
57 .stdin(Stdio::piped())
58 .stdout(Stdio::piped())
59 .stderr(Stdio::null())
60 .spawn()
61 .await;
62
63 let mut child = match res {
64 Ok(child) => child,
65 Err(err) => {
66 #[cfg(feature = "tracing")]
67 tracing::error!(
68 "Failed to connect to remote sftp subsystem: {err}, session = {session:?}"
69 );
70
71 tx.send(Err(err)).unwrap(); return None;
73 }
74 };
75
76 #[cfg(feature = "tracing")]
77 tracing::info!("Connection to sftp subsystem established, session = {session:?}");
78
79 let stdin = child.stdin().take().unwrap();
80 let stdout = child.stdout().take().unwrap();
81 tx.send(Ok((stdin, stdout))).unwrap(); let original_error = {
84 let check_conn_future = async {
85 if let Some(checker) = check_openssh_connection {
86 checker
87 .check_connection(&session)
88 .await
89 .err()
90 .map(Error::from)
91 } else {
92 None
93 }
94 };
95
96 let wait_on_child_future = async {
97 match child.wait().await {
98 Ok(exit_status) => {
99 if !exit_status.success() {
100 Some(Error::SftpServerFailure(exit_status))
101 } else {
102 None
103 }
104 }
105 Err(err) => Some(err.into()),
106 }
107 };
108 tokio::pin!(wait_on_child_future);
109
110 tokio::select! {
111 biased;
112
113 original_error = check_conn_future => {
114 let occuring_error = wait_on_child_future.await;
115 match (original_error, occuring_error) {
116 (Some(original_error), Some(occuring_error)) => {
117 Some(original_error.error_on_cleanup(occuring_error))
118 }
119 (Some(err), None) | (None, Some(err)) => Some(err),
120 (None, None) => None,
121 }
122 }
123 original_error = &mut wait_on_child_future => original_error,
124 }
125 };
126
127 #[cfg(feature = "tracing")]
128 if let Some(err) = &original_error {
129 tracing::error!(
130 "Waiting on remote sftp subsystem to exit failed: {err}, session = {session:?}"
131 );
132 }
133
134 original_error
135}
136
137impl Sftp {
138 pub async fn from_session(session: Session, options: SftpOptions) -> Result<Self, Error> {
145 Self::from_session_with_check_connection_inner(session, options, None).await
146 }
147
148 pub async fn from_session_with_check_connection(
177 session: Session,
178 options: SftpOptions,
179 check_openssh_connection: impl CheckOpensshConnection + Send + Sync + 'static,
180 ) -> Result<Self, Error> {
181 Self::from_session_with_check_connection_inner(
182 session,
183 options,
184 Some(Box::new(check_openssh_connection)),
185 )
186 .await
187 }
188
189 async fn from_session_with_check_connection_inner(
190 session: Session,
191 options: SftpOptions,
192 check_openssh_connection: Option<Box<dyn CheckOpensshConnection + Send + Sync>>,
193 ) -> Result<Self, Error> {
194 let (tx, rx) = oneshot::channel();
195
196 Self::from_session_task(
197 options,
198 rx,
199 tokio::spawn(async move {
200 let original_error =
201 create_session_task(&session, tx, check_openssh_connection).await;
202
203 let _session_str = format!("{session:?}");
204 let occuring_error = session.close().await.err().map(Error::from);
205
206 #[cfg(feature = "tracing")]
207 if let Some(err) = &occuring_error {
208 tracing::error!("Closing session failed: {err}, session = {_session_str}");
209 }
210
211 match (original_error, occuring_error) {
212 (Some(original_error), Some(occuring_error)) => {
213 Some(original_error.error_on_cleanup(occuring_error))
214 }
215 (Some(err), None) | (None, Some(err)) => Some(err),
216 (None, None) => None,
217 }
218 }),
219 )
220 .await
221 }
222
223 pub async fn from_clonable_session(
226 session: impl Deref<Target = Session> + Clone + Debug + Send + Sync + 'static,
227 options: SftpOptions,
228 ) -> Result<Self, Error> {
229 Self::from_clonable_session_with_check_connection_inner(session, options, None).await
230 }
231
232 pub async fn from_clonable_session_with_check_connection(
261 session: impl Deref<Target = Session> + Clone + Debug + Send + Sync + 'static,
262 options: SftpOptions,
263 check_openssh_connection: impl CheckOpensshConnection + Send + Sync + 'static,
264 ) -> Result<Self, Error> {
265 Self::from_clonable_session_with_check_connection_inner(
266 session,
267 options,
268 Some(Box::new(check_openssh_connection)),
269 )
270 .await
271 }
272
273 async fn from_clonable_session_with_check_connection_inner(
274 session: impl Deref<Target = Session> + Clone + Debug + Send + Sync + 'static,
275 options: SftpOptions,
276 check_openssh_connection: Option<Box<dyn CheckOpensshConnection + Send + Sync>>,
277 ) -> Result<Self, Error> {
278 let (tx, rx) = oneshot::channel();
279
280 Self::from_session_task(
281 options,
282 rx,
283 tokio::spawn(create_session_task(session, tx, check_openssh_connection)),
284 )
285 .await
286 }
287
288 async fn from_session_task(
289 options: SftpOptions,
290 rx: oneshot::Receiver<Result<(ChildStdin, ChildStdout), OpensshError>>,
291 handle: JoinHandle<Option<Error>>,
292 ) -> Result<Self, Error> {
293 let msg = "Task failed without sending anything, so it must have panicked";
294
295 let (stdin, stdout) = match rx.await {
296 Ok(res) => res?,
297 Err(_) => return Err(handle.await.expect_err(msg).into()),
298 };
299
300 Self::new_with_auxiliary(
301 stdin,
302 stdout,
303 options,
304 SftpAuxiliaryData::ArcedOpensshSession(Arc::new(OpensshSession(handle))),
305 )
306 .await
307 }
308}
309
310impl OpensshSession {
311 pub(super) async fn recover_session_err(mut self) -> Result<(), Error> {
312 if let Some(err) = (&mut self.0).await? {
313 Err(err)
314 } else {
315 Ok(())
316 }
317 }
318}