use std::{
any::Any,
ffi::{CStr, CString},
marker::PhantomData,
ops::Deref,
os::raw::c_char,
ptr::NonNull,
sync::Arc
};
use crate::{
AsPointer, char_p_to_string,
environment::Environment,
error::{Error, ErrorCode, Result, assert_non_null_pointer, status_to_result},
extern_system_fn,
io_binding::IoBinding,
memory::Allocator,
metadata::ModelMetadata,
ortsys,
value::{Value, ValueType}
};
mod r#async;
pub mod builder;
pub mod input;
pub mod output;
pub mod run_options;
pub use self::{
r#async::InferenceFut,
input::{SessionInputValue, SessionInputs},
output::SessionOutputs,
run_options::{HasSelectedOutputs, NoSelectedOutputs, RunOptions, SelectedOutputMarker}
};
use self::{
r#async::{AsyncInferenceContext, InferenceFutInner, RunOptionsRef},
builder::SessionBuilder
};
/// Shared state behind a [`Session`]: the raw `OrtSession` handle together with the objects stored next to it.
///
/// The handle is passed to `ReleaseSession` when this value is dropped.
#[derive(Debug)]
pub struct SharedSessionInner {
/// Non-null pointer to the underlying `OrtSession`.
session_ptr: NonNull<ort_sys::OrtSession>,
/// Allocator handed out by [`Session::allocator`] and used when freeing strings returned by the session.
pub(crate) allocator: Allocator,
// NOTE(review): never read in this file; presumably holds auxiliary objects that must outlive the
// session — confirm against the code that constructs this struct.
_extras: Vec<Box<dyn Any>>,
// NOTE(review): never read in this file; presumably keeps the environment alive for as long as the
// session exists — confirm.
_environment: Arc<Environment>
}
// NOTE(review): these impls assert that the raw `OrtSession` pointer and the other fields may be moved to
// and shared between threads. Nothing in this file proves that; confirm against ONNX Runtime's
// documented threading guarantees for sessions.
unsafe impl Send for SharedSessionInner {}
unsafe impl Sync for SharedSessionInner {}
/// Exposes the stored `OrtSession` pointer through the crate's [`AsPointer`] trait.
impl AsPointer for SharedSessionInner {
type Sys = ort_sys::OrtSession;
/// Returns the stored session pointer as a `*const`; the pointer was non-null when stored.
fn ptr(&self) -> *const Self::Sys {
self.session_ptr.as_ptr()
}
}
impl Drop for SharedSessionInner {
    /// Emits a debug log entry naming the handle, then hands it to `ReleaseSession`.
    fn drop(&mut self) {
        let raw_session = self.session_ptr.as_ptr();
        tracing::debug!(ptr = ?raw_session, "dropping SharedSessionInner");
        ortsys![unsafe ReleaseSession(raw_session)];
    }
}
/// A loaded model that can be run, along with the description of its inputs and outputs.
#[derive(Debug)]
pub struct Session {
/// Reference-counted session state; also cloned into output values created by `run`.
pub(crate) inner: Arc<SharedSessionInner>,
/// Inputs of the model. `run` uses these names, in this order, when values are passed positionally.
pub inputs: Vec<Input>,
/// Outputs of the model. `run` requests all of them unless run options select otherwise.
pub outputs: Vec<Output>
}
/// A [`Session`] wrapper that carries an extra lifetime `'s` and dereferences to the wrapped session.
///
/// NOTE(review): presumably `'s` is the lifetime of the in-memory model bytes the session was created
/// from — confirm against the code that constructs this type.
pub struct InMemorySession<'s> {
// The wrapped session; reachable through `Deref`.
session: Session,
// Zero-sized marker that ties this value to `'s` without storing a reference.
phantom: PhantomData<&'s ()>
}
/// Lets an [`InMemorySession`] be used wherever a `&Session` is expected.
impl Deref for InMemorySession<'_> {
type Target = Session;
/// Borrows the wrapped [`Session`].
fn deref(&self) -> &Self::Target {
&self.session
}
}
/// Description of one model input.
#[derive(Debug)]
pub struct Input {
/// Name of the input; this is the name passed to ONNX Runtime when running the session.
pub name: String,
/// Type information reported by the session for this input.
pub input_type: ValueType
}
/// Description of one model output.
#[derive(Debug)]
pub struct Output {
/// Name of the output; this is the name requested from ONNX Runtime when running the session.
pub name: String,
/// Type information reported by the session for this output.
pub output_type: ValueType
}
impl Session {
/// Creates a new [`SessionBuilder`].
///
/// # Errors
/// Returns whatever error `SessionBuilder::new` reports.
pub fn builder() -> Result<SessionBuilder> {
SessionBuilder::new()
}
/// Returns a reference to the allocator stored in this session's shared state.
#[must_use]
pub fn allocator(&self) -> &Allocator {
&self.inner.allocator
}
/// Creates an [`IoBinding`] for this session by calling `IoBinding::new(self)`.
///
/// # Errors
/// Returns whatever error `IoBinding::new` reports.
pub fn create_binding(&self) -> Result<IoBinding> {
IoBinding::new(self)
}
/// Returns another `Arc` handle to the shared session state.
///
/// Only the reference count is incremented; the session itself is not copied.
#[must_use]
pub fn inner(&self) -> Arc<SharedSessionInner> {
Arc::clone(&self.inner)
}
/// Lists the model's overridable initializers with their names and types.
///
/// The count, each name and each type are queried from ONNX Runtime one index at a time.
///
/// # Panics
/// Panics if any of the underlying ONNX Runtime calls reports an error (they are treated as infallible).
#[must_use]
pub fn overridable_initializers(&self) -> Vec<OverridableInitializer> {
    let mut size = 0;
    ortsys![unsafe SessionGetOverridableInitializerCount(self.ptr(), &mut size).expect("infallible")];
    let allocator = Allocator::default();
    (0..size)
        .map(|i| {
            let mut name_ptr: *mut c_char = std::ptr::null_mut();
            ortsys![unsafe SessionGetOverridableInitializerName(self.ptr(), i, allocator.ptr().cast_mut(), &mut name_ptr).expect("infallible")];
            // Copy the name into an owned `String` first...
            let name = unsafe { CStr::from_ptr(name_ptr) }.to_string_lossy().into_owned();
            // ...then give the buffer back to the allocator that produced it. Previously the buffer was
            // never freed, leaking one allocation per initializer on every call.
            unsafe { allocator.free(name_ptr) };

            let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut();
            ortsys![unsafe SessionGetOverridableInitializerTypeInfo(self.ptr(), i, &mut typeinfo_ptr).expect("infallible")];
            let dtype = ValueType::from_type_info(typeinfo_ptr);
            OverridableInitializer { name, dtype }
        })
        .collect()
}
/// Runs the model on `input_values` with no run options and returns all of its outputs.
///
/// Values given as a slice or array are paired with the names in `self.inputs`, in order; values given
/// as a map are paired with the map's own keys.
///
/// # Errors
/// Returns any error produced by the underlying run (see `run_inner`).
pub fn run<'s, 'i, 'v: 'i, const N: usize>(&'s self, input_values: impl Into<SessionInputs<'i, 'v, N>>) -> Result<SessionOutputs<'s, 's>> {
    // Names used for positional inputs: the model's declared inputs, in declaration order.
    let declared_names = || self.inputs.iter().map(|input| input.name.as_str()).collect::<Vec<_>>();
    match input_values.into() {
        SessionInputs::ValueSlice(values) => self.run_inner::<NoSelectedOutputs>(&declared_names(), values.iter(), None),
        SessionInputs::ValueArray(values) => self.run_inner::<NoSelectedOutputs>(&declared_names(), values.iter(), None),
        SessionInputs::ValueMap(entries) => {
            let names: Vec<&str> = entries.iter().map(|(name, _)| name.as_ref()).collect();
            self.run_inner::<NoSelectedOutputs>(&names, entries.iter().map(|(_, value)| value), None)
        }
    }
}
/// Runs the model on `input_values` using the given [`RunOptions`].
///
/// Values given as a slice or array are paired with the names in `self.inputs`, in order; values given
/// as a map are paired with the map's own keys.
///
/// # Errors
/// Returns any error produced by the underlying run (see `run_inner`).
pub fn run_with_options<'r, 's: 'r, 'i, 'v: 'i, O: SelectedOutputMarker, const N: usize>(
    &'s self,
    input_values: impl Into<SessionInputs<'i, 'v, N>>,
    run_options: &'r RunOptions<O>
) -> Result<SessionOutputs<'r, 's>> {
    // Names used for positional inputs: the model's declared inputs, in declaration order.
    let declared_names = || self.inputs.iter().map(|input| input.name.as_str()).collect::<Vec<_>>();
    match input_values.into() {
        SessionInputs::ValueSlice(values) => self.run_inner(&declared_names(), values.iter(), Some(run_options)),
        SessionInputs::ValueArray(values) => self.run_inner(&declared_names(), values.iter(), Some(run_options)),
        SessionInputs::ValueMap(entries) => {
            let names: Vec<&str> = entries.iter().map(|(name, _)| name.as_ref()).collect();
            self.run_inner(&names, entries.iter().map(|(_, value)| value), Some(run_options))
        }
    }
}
fn run_inner<'i, 'r, 's: 'r, 'v: 'i, O: SelectedOutputMarker>(
&'s self,
input_names: &[&str],
input_values: impl Iterator<Item = &'i SessionInputValue<'v>>,
run_options: Option<&'r RunOptions<O>>
) -> Result<SessionOutputs<'r, 's>> {
let input_names_ptr: Vec<*const c_char> = input_names
.iter()
.map(|n| CString::new(n.as_bytes()).unwrap_or_else(|_| unreachable!()))
.map(|n| n.into_raw().cast_const())
.collect();
let (output_names, mut output_tensors) = match run_options {
Some(r) => r.outputs.resolve_outputs(&self.outputs),
None => (self.outputs.iter().map(|o| o.name.as_str()).collect(), std::iter::repeat_with(|| None).take(self.outputs.len()).collect())
};
let output_names_ptr: Vec<*const c_char> = output_names
.iter()
.map(|n| CString::new(*n).unwrap_or_else(|_| unreachable!()))
.map(|n| n.into_raw().cast_const())
.collect();
let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = output_tensors
.iter_mut()
.map(|c| match c {
Some(v) => v.ptr_mut(),
None => std::ptr::null_mut()
})
.collect();
let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr()).collect();
if input_ort_values.len() > input_names.len() {
return Err(Error::new_with_code(
ErrorCode::InvalidArgument,
format!("{} inputs were provided, but the model only accepts {}.", input_ort_values.len(), input_names.len())
));
}
let run_options_ptr = if let Some(run_options) = &run_options { run_options.ptr() } else { std::ptr::null() };
ortsys![
unsafe Run(
self.inner.session_ptr.as_ptr(),
run_options_ptr,
input_names_ptr.as_ptr(),
input_ort_values.as_ptr(),
input_ort_values.len(),
output_names_ptr.as_ptr(),
output_names_ptr.len(),
output_tensor_ptrs.as_mut_ptr()
)?
];
let outputs: Vec<Value> = output_tensors
.into_iter()
.enumerate()
.map(|(i, v)| match v {
Some(value) => value,
None => unsafe {
Value::from_ptr(
NonNull::new(output_tensor_ptrs[i]).expect("OrtValue ptr returned from session Run should not be null"),
Some(Arc::clone(&self.inner))
)
}
})
.collect();
for p in input_names_ptr.into_iter().chain(output_names_ptr.into_iter()) {
drop(unsafe { CString::from_raw(p.cast_mut().cast()) });
}
Ok(SessionOutputs::new(output_names, outputs))
}
pub fn run_async<'s, 'i, 'v: 'i + 's, const N: usize>(
&'s self,
input_values: impl Into<SessionInputs<'i, 'v, N>> + 'static
) -> Result<InferenceFut<'s, 's, NoSelectedOutputs>> {
match input_values.into() {
SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"),
SessionInputs::ValueArray(input_values) => {
self.run_inner_async(&self.inputs.iter().map(|input| input.name.to_string()).collect::<Vec<_>>(), input_values.into_iter(), None)
}
SessionInputs::ValueMap(input_values) => {
self.run_inner_async(&input_values.iter().map(|(k, _)| k.to_string()).collect::<Vec<_>>(), input_values.into_iter().map(|(_, v)| v), None)
}
}
}
pub fn run_async_with_options<'s, 'i, 'v: 'i + 's, 'r, O: SelectedOutputMarker, const N: usize>(
&'s self,
input_values: impl Into<SessionInputs<'i, 'v, N>> + 'static,
run_options: &'r RunOptions<O>
) -> Result<InferenceFut<'s, 'r, O>> {
match input_values.into() {
SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"),
SessionInputs::ValueArray(input_values) => {
self.run_inner_async(&self.inputs.iter().map(|input| input.name.to_string()).collect::<Vec<_>>(), input_values.into_iter(), Some(run_options))
}
SessionInputs::ValueMap(input_values) => self.run_inner_async(
&input_values.iter().map(|(k, _)| k.to_string()).collect::<Vec<_>>(),
input_values.into_iter().map(|(_, v)| v),
Some(run_options)
)
}
}
fn run_inner_async<'s, 'v: 's, 'r, O: SelectedOutputMarker>(
&'s self,
input_names: &[String],
input_values: impl Iterator<Item = SessionInputValue<'v>>,
run_options: Option<&'r RunOptions<O>>
) -> Result<InferenceFut<'s, 'r, O>> {
let run_options = match run_options {
Some(r) => RunOptionsRef::Ref(r),
None => RunOptionsRef::Arc(Arc::new(unsafe {
std::mem::transmute::<RunOptions<NoSelectedOutputs>, RunOptions<O>>(RunOptions::new()?)
}))
};
let input_name_ptrs: Vec<*const c_char> = input_names
.iter()
.map(|n| CString::new(n.as_bytes()).unwrap_or_else(|_| unreachable!()))
.map(|n| n.into_raw().cast_const())
.collect();
let output_name_ptrs: Vec<*const c_char> = self
.outputs
.iter()
.map(|output| CString::new(output.name.as_str()).unwrap_or_else(|_| unreachable!()))
.map(|n| n.into_raw().cast_const())
.collect();
let output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.outputs.len()];
let input_values: Vec<_> = input_values.collect();
let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.iter().map(|input_array_ort| input_array_ort.ptr()).collect();
let async_inner = Arc::new(InferenceFutInner::new());
let ctx = Box::leak(Box::new(AsyncInferenceContext {
inner: Arc::clone(&async_inner),
_input_values: input_values,
input_ort_values,
input_name_ptrs,
output_name_ptrs,
output_names: self.outputs.iter().map(|o| o.name.as_str()).collect::<Vec<_>>(),
output_value_ptrs: output_tensor_ptrs,
session_inner: &self.inner
}));
ortsys![
unsafe RunAsync(
self.inner.session_ptr.as_ptr(),
run_options.ptr(),
ctx.input_name_ptrs.as_ptr(),
ctx.input_ort_values.as_ptr(),
ctx.input_ort_values.len(),
ctx.output_name_ptrs.as_ptr(),
ctx.output_name_ptrs.len(),
ctx.output_value_ptrs.as_mut_ptr(),
Some(self::r#async::async_callback),
ctx as *mut _ as *mut ort_sys::c_void
)?
];
Ok(InferenceFut::new(async_inner, run_options))
}
/// Fetches the model metadata for this session via `SessionGetModelMetadata`.
///
/// The returned [`ModelMetadata`] borrows this session's allocator.
///
/// # Errors
/// Returns the error reported by `SessionGetModelMetadata`, if any.
pub fn metadata(&self) -> Result<ModelMetadata<'_>> {
let mut metadata_ptr: *mut ort_sys::OrtModelMetadata = std::ptr::null_mut();
// NOTE(review): the `nonNull(metadata_ptr)` clause presumably makes the macro reject a null result,
// which is what `new_unchecked` below relies on — confirm against the `ortsys!` definition.
ortsys![unsafe SessionGetModelMetadata(self.inner.session_ptr.as_ptr(), &mut metadata_ptr)?; nonNull(metadata_ptr)];
Ok(ModelMetadata::new(unsafe { NonNull::new_unchecked(metadata_ptr) }, &self.inner.allocator))
}
/// Ends profiling for this session and returns the string ONNX Runtime reports for it.
///
/// NOTE(review): the returned string is presumably the path of the written profile file — confirm
/// against the ONNX Runtime documentation for `SessionEndProfiling`.
///
/// # Errors
/// * The error reported by `SessionEndProfiling`, if the call fails.
/// * An error if the call succeeds but leaves the name pointer null, or if the string cannot be
///   converted.
pub fn end_profiling(&self) -> Result<String> {
    let mut profiling_name: *mut c_char = std::ptr::null_mut();
    // Propagate the status with `?`: previously it was discarded, so a failing call surfaced only as
    // a generic null-pointer error (or not at all) and the real message was lost.
    ortsys![unsafe SessionEndProfiling(self.inner.session_ptr.as_ptr(), self.inner.allocator.ptr().cast_mut(), &mut profiling_name)?];
    assert_non_null_pointer(profiling_name, "ProfilingName")?;
    // Copies the string and frees the original buffer with the session's allocator.
    dangerous::raw_pointer_to_string(&self.inner.allocator, profiling_name)
}
/// Sets the `ep.dynamic.workload_type` dynamic option of this session to the given [`WorkloadType`].
///
/// # Errors
/// Returns the error reported when setting the dynamic option fails.
pub fn set_workload_type(&self, workload_type: WorkloadType) -> Result<()> {
    static KEY: &[u8] = b"ep.dynamic.workload_type\0";
    // Pick the NUL-terminated option value first, then make a single call.
    let value: &'static [u8] = match workload_type {
        WorkloadType::Default => b"Default\0",
        WorkloadType::Efficient => b"Efficient\0"
    };
    self.set_dynamic_option(KEY.as_ptr().cast(), value.as_ptr().cast())
}
/// Passes a single key/value pair to `SetEpDynamicOptions` for this session.
///
/// NOTE(review): `key` and `value` are handed to the C API unchanged, so presumably both must point to
/// valid NUL-terminated strings; the caller in this file passes static byte literals ending in `\0`.
///
/// # Errors
/// Returns the error reported by `SetEpDynamicOptions`, if any.
pub(crate) fn set_dynamic_option(&self, key: *const c_char, value: *const c_char) -> Result<()> {
// `&key` / `&value` act as one-element arrays, matching the count of 1.
ortsys![unsafe SetEpDynamicOptions(self.inner.session_ptr.as_ptr(), &key, &value, 1)?];
Ok(())
}
}
/// Workload type accepted by `Session::set_workload_type`, which sends it as the value of the
/// `ep.dynamic.workload_type` dynamic option.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkloadType {
/// Sent as the string `Default`; also the `Default::default()` value of this enum.
#[default]
Default,
/// Sent as the string `Efficient`.
Efficient
}
// NOTE(review): these impls assert that a `Session` may be moved to and shared between threads. Nothing
// in this file proves that; confirm against ONNX Runtime's documented threading guarantees.
unsafe impl Send for Session {}
unsafe impl Sync for Session {}
/// Exposes the raw `OrtSession` pointer of a [`Session`] through the crate's [`AsPointer`] trait.
impl AsPointer for Session {
type Sys = ort_sys::OrtSession;
/// Returns the pointer held by the shared session state.
fn ptr(&self) -> *const Self::Sys {
self.inner.ptr()
}
}
/// Name and type of one overridable initializer, as returned by `Session::overridable_initializers`.
#[derive(Debug, Clone)]
pub struct OverridableInitializer {
// Name reported by the session for this initializer.
name: String,
// Type information reported by the session for this initializer.
dtype: ValueType
}
impl OverridableInitializer {
/// Returns the initializer's name.
pub fn name(&self) -> &str {
&self.name
}
/// Returns the initializer's type information.
pub fn dtype(&self) -> &ValueType {
&self.dtype
}
}
mod dangerous {
use super::*;
pub(super) fn extract_inputs_count(session_ptr: NonNull<ort_sys::OrtSession>) -> Result<usize> {
let f = ortsys![SessionGetInputCount];
extract_io_count(f, session_ptr)
}
pub(super) fn extract_outputs_count(session_ptr: NonNull<ort_sys::OrtSession>) -> Result<usize> {
let f = ortsys![SessionGetOutputCount];
extract_io_count(f, session_ptr)
}
fn extract_io_count(
f: extern_system_fn! { unsafe fn(*const ort_sys::OrtSession, *mut usize) -> *mut ort_sys::OrtStatus },
session_ptr: NonNull<ort_sys::OrtSession>
) -> Result<usize> {
let mut num_nodes = 0;
let status = unsafe { f(session_ptr.as_ptr(), &mut num_nodes) };
status_to_result(status)?;
Ok(num_nodes)
}
fn extract_input_name(session_ptr: NonNull<ort_sys::OrtSession>, allocator: &Allocator, i: usize) -> Result<String> {
let f = ortsys![SessionGetInputName];
extract_io_name(f, session_ptr, allocator, i)
}
fn extract_output_name(session_ptr: NonNull<ort_sys::OrtSession>, allocator: &Allocator, i: usize) -> Result<String> {
let f = ortsys![SessionGetOutputName];
extract_io_name(f, session_ptr, allocator, i)
}
pub(crate) fn raw_pointer_to_string(allocator: &Allocator, c_str: *mut c_char) -> Result<String> {
let name = match char_p_to_string(c_str) {
Ok(name) => name,
Err(e) => {
unsafe { allocator.free(c_str) };
return Err(e);
}
};
unsafe { allocator.free(c_str) };
Ok(name)
}
fn extract_io_name(
f: extern_system_fn! { unsafe fn(
*const ort_sys::OrtSession,
usize,
*mut ort_sys::OrtAllocator,
*mut *mut c_char,
) -> *mut ort_sys::OrtStatus },
session_ptr: NonNull<ort_sys::OrtSession>,
allocator: &Allocator,
i: usize
) -> Result<String> {
let mut name_bytes: *mut c_char = std::ptr::null_mut();
let status = unsafe { f(session_ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_bytes) };
status_to_result(status)?;
assert_non_null_pointer(name_bytes, "InputName")?;
raw_pointer_to_string(allocator, name_bytes)
}
pub(super) fn extract_input(session_ptr: NonNull<ort_sys::OrtSession>, allocator: &Allocator, i: usize) -> Result<Input> {
let input_name = extract_input_name(session_ptr, allocator, i)?;
let f = ortsys![SessionGetInputTypeInfo];
let input_type = extract_io(f, session_ptr, i)?;
Ok(Input { name: input_name, input_type })
}
pub(super) fn extract_output(session_ptr: NonNull<ort_sys::OrtSession>, allocator: &Allocator, i: usize) -> Result<Output> {
let output_name = extract_output_name(session_ptr, allocator, i)?;
let f = ortsys![SessionGetOutputTypeInfo];
let output_type = extract_io(f, session_ptr, i)?;
Ok(Output { name: output_name, output_type })
}
fn extract_io(
f: extern_system_fn! { unsafe fn(
*const ort_sys::OrtSession,
usize,
*mut *mut ort_sys::OrtTypeInfo,
) -> *mut ort_sys::OrtStatus },
session_ptr: NonNull<ort_sys::OrtSession>,
i: usize
) -> Result<ValueType> {
let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut();
let status = unsafe { f(session_ptr.as_ptr(), i, &mut typeinfo_ptr) };
status_to_result(status)?;
assert_non_null_pointer(typeinfo_ptr, "TypeInfo")?;
Ok(ValueType::from_type_info(typeinfo_ptr))
}
}