openssh_sftp_client/sftp/
openssh_session.rsuse std::{fmt::Debug, future::Future, ops::Deref, pin::Pin, sync::Arc};
use openssh::{ChildStdin, ChildStdout, Error as OpensshError, Session, Stdio};
use tokio::{sync::oneshot, task::JoinHandle};
use crate::{utils::ErrorExt, Error, Sftp, SftpAuxiliaryData, SftpOptions};
#[derive(Debug)]
pub struct OpensshSession(JoinHandle<Option<Error>>);
pub trait CheckOpensshConnection {
fn check_connection<'session>(
self: Box<Self>,
session: &'session Session,
) -> Pin<Box<dyn Future<Output = Result<(), OpensshError>> + Send + Sync + 'session>>;
}
impl<F> CheckOpensshConnection for F
where
F: for<'session> FnOnce(
&'session Session,
) -> Pin<
Box<dyn Future<Output = Result<(), OpensshError>> + Send + Sync + 'session>,
>,
{
fn check_connection<'session>(
self: Box<Self>,
session: &'session Session,
) -> Pin<Box<dyn Future<Output = Result<(), OpensshError>> + Send + Sync + 'session>> {
(self)(session)
}
}
impl Drop for OpensshSession {
fn drop(&mut self) {
self.0.abort();
}
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(name = "session_task", skip(tx, check_openssh_connection))
)]
async fn create_session_task(
session: impl Deref<Target = Session> + Clone + Debug + Send + Sync,
tx: oneshot::Sender<Result<(ChildStdin, ChildStdout), OpensshError>>,
check_openssh_connection: Option<Box<dyn CheckOpensshConnection + Send + Sync>>,
) -> Option<Error> {
#[cfg(feature = "tracing")]
tracing::info!("Connecting to sftp subsystem, session = {session:?}");
let res = Session::to_subsystem(session.clone(), "sftp")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::null())
.spawn()
.await;
let mut child = match res {
Ok(child) => child,
Err(err) => {
#[cfg(feature = "tracing")]
tracing::error!(
"Failed to connect to remote sftp subsystem: {err}, session = {session:?}"
);
tx.send(Err(err)).unwrap(); return None;
}
};
#[cfg(feature = "tracing")]
tracing::info!("Connection to sftp subsystem established, session = {session:?}");
let stdin = child.stdin().take().unwrap();
let stdout = child.stdout().take().unwrap();
tx.send(Ok((stdin, stdout))).unwrap(); let original_error = {
let check_conn_future = async {
if let Some(checker) = check_openssh_connection {
checker
.check_connection(&session)
.await
.err()
.map(Error::from)
} else {
None
}
};
let wait_on_child_future = async {
match child.wait().await {
Ok(exit_status) => {
if !exit_status.success() {
Some(Error::SftpServerFailure(exit_status))
} else {
None
}
}
Err(err) => Some(err.into()),
}
};
tokio::pin!(wait_on_child_future);
tokio::select! {
biased;
original_error = check_conn_future => {
let occuring_error = wait_on_child_future.await;
match (original_error, occuring_error) {
(Some(original_error), Some(occuring_error)) => {
Some(original_error.error_on_cleanup(occuring_error))
}
(Some(err), None) | (None, Some(err)) => Some(err),
(None, None) => None,
}
}
original_error = &mut wait_on_child_future => original_error,
}
};
#[cfg(feature = "tracing")]
if let Some(err) = &original_error {
tracing::error!(
"Waiting on remote sftp subsystem to exit failed: {err}, session = {session:?}"
);
}
original_error
}
impl Sftp {
pub async fn from_session(session: Session, options: SftpOptions) -> Result<Self, Error> {
Self::from_session_with_check_connection_inner(session, options, None).await
}
pub async fn from_session_with_check_connection(
session: Session,
options: SftpOptions,
check_openssh_connection: impl CheckOpensshConnection + Send + Sync + 'static,
) -> Result<Self, Error> {
Self::from_session_with_check_connection_inner(
session,
options,
Some(Box::new(check_openssh_connection)),
)
.await
}
async fn from_session_with_check_connection_inner(
session: Session,
options: SftpOptions,
check_openssh_connection: Option<Box<dyn CheckOpensshConnection + Send + Sync>>,
) -> Result<Self, Error> {
let (tx, rx) = oneshot::channel();
Self::from_session_task(
options,
rx,
tokio::spawn(async move {
let original_error =
create_session_task(&session, tx, check_openssh_connection).await;
let _session_str = format!("{session:?}");
let occuring_error = session.close().await.err().map(Error::from);
#[cfg(feature = "tracing")]
if let Some(err) = &occuring_error {
tracing::error!("Closing session failed: {err}, session = {_session_str}");
}
match (original_error, occuring_error) {
(Some(original_error), Some(occuring_error)) => {
Some(original_error.error_on_cleanup(occuring_error))
}
(Some(err), None) | (None, Some(err)) => Some(err),
(None, None) => None,
}
}),
)
.await
}
pub async fn from_clonable_session(
session: impl Deref<Target = Session> + Clone + Debug + Send + Sync + 'static,
options: SftpOptions,
) -> Result<Self, Error> {
Self::from_clonable_session_with_check_connection_inner(session, options, None).await
}
pub async fn from_clonable_session_with_check_connection(
session: impl Deref<Target = Session> + Clone + Debug + Send + Sync + 'static,
options: SftpOptions,
check_openssh_connection: impl CheckOpensshConnection + Send + Sync + 'static,
) -> Result<Self, Error> {
Self::from_clonable_session_with_check_connection_inner(
session,
options,
Some(Box::new(check_openssh_connection)),
)
.await
}
async fn from_clonable_session_with_check_connection_inner(
session: impl Deref<Target = Session> + Clone + Debug + Send + Sync + 'static,
options: SftpOptions,
check_openssh_connection: Option<Box<dyn CheckOpensshConnection + Send + Sync>>,
) -> Result<Self, Error> {
let (tx, rx) = oneshot::channel();
Self::from_session_task(
options,
rx,
tokio::spawn(create_session_task(session, tx, check_openssh_connection)),
)
.await
}
async fn from_session_task(
options: SftpOptions,
rx: oneshot::Receiver<Result<(ChildStdin, ChildStdout), OpensshError>>,
handle: JoinHandle<Option<Error>>,
) -> Result<Self, Error> {
let msg = "Task failed without sending anything, so it must have panicked";
let (stdin, stdout) = match rx.await {
Ok(res) => res?,
Err(_) => return Err(handle.await.expect_err(msg).into()),
};
Self::new_with_auxiliary(
stdin,
stdout,
options,
SftpAuxiliaryData::ArcedOpensshSession(Arc::new(OpensshSession(handle))),
)
.await
}
}
impl OpensshSession {
pub(super) async fn recover_session_err(mut self) -> Result<(), Error> {
if let Some(err) = (&mut self.0).await? {
Err(err)
} else {
Ok(())
}
}
}