intuicio_ffi/
lib.rs

1use intuicio_core::{
2    context::Context,
3    function::{Function, FunctionBody, FunctionQuery, FunctionSignature},
4    types::TypeHandle,
5};
6use libffi::raw::{
7    ffi_abi_FFI_DEFAULT_ABI, ffi_call, ffi_cif, ffi_prep_cif, ffi_type, ffi_type_void,
8    FFI_TYPE_STRUCT,
9};
10use libloading::Library;
11use std::{
12    error::Error,
13    ffi::{c_void, OsString},
14    path::Path,
15    ptr::null_mut,
16    str::FromStr,
17    sync::Arc,
18};
19
20pub use libffi::low::CodePtr as FfiCodePtr;
21
/// Shared, reference-counted handle to an [`FfiFunction`], as returned by
/// [`FfiLibrary::function`] and [`FfiLibrary::find`].
pub type FfiFunctionHandle = Arc<FfiFunction>;
23
/// A dynamically loaded library together with the native functions resolved from it.
///
/// Functions are resolved and registered with [`FfiLibrary::function`] and can be
/// looked up again with [`FfiLibrary::find`].
pub struct FfiLibrary {
    /// Handle of the loaded library; resolved function addresses point into it.
    library: Library,
    /// Registered functions, each paired with the signature it was resolved with.
    functions: Vec<(FunctionSignature, FfiFunctionHandle)>,
    /// Path the library was loaded from, converted lossily to a string.
    name: String,
}
29
30impl FfiLibrary {
31    pub fn new(path: impl AsRef<Path>) -> Result<Self, Box<dyn Error>> {
32        let mut path = path.as_ref().to_path_buf();
33        if path.extension().is_none() {
34            path.set_extension(std::env::consts::DLL_EXTENSION);
35        }
36        Ok(Self {
37            library: unsafe { Library::new(path.as_os_str())? },
38            functions: Default::default(),
39            name: path.to_string_lossy().to_string(),
40        })
41    }
42
43    pub fn name(&self) -> &str {
44        &self.name
45    }
46
47    pub fn function(
48        &mut self,
49        signature: FunctionSignature,
50    ) -> Result<FfiFunctionHandle, Box<dyn Error>> {
51        unsafe {
52            let symbol = OsString::from_str(&signature.name)?;
53            let symbol = self
54                .library
55                .get::<unsafe extern "C" fn()>(symbol.as_encoded_bytes())?;
56            let Some(function) = symbol.try_as_raw_ptr() else {
57                return Err(format!("Could not get pointer of function: `{}`", signature).into());
58            };
59            let handle = Arc::new(FfiFunction::from_function_signature(
60                FfiCodePtr(function),
61                &signature,
62            ));
63            for (s, h) in &mut self.functions {
64                if s == &signature {
65                    *h = handle.clone();
66                    return Ok(handle);
67                }
68            }
69            self.functions.push((signature, handle.clone()));
70            Ok(handle)
71        }
72    }
73
74    pub fn find(&self, query: FunctionQuery) -> Option<FfiFunctionHandle> {
75        self.functions.iter().find_map(|(signature, handle)| {
76            if query.is_valid(signature) {
77                Some(handle.clone())
78            } else {
79                None
80            }
81        })
82    }
83}
84
/// Description of a native function that can be called through libffi.
///
/// Stores the function's raw code pointer plus the intuicio types of its arguments
/// and optional result; [`FfiFunction::call`] uses them to move values between the
/// context stack and the native call.
#[derive(Debug, Clone)]
pub struct FfiFunction {
    /// Address of the native function.
    function: FfiCodePtr,
    /// Type of the returned value, or `None` for a function returning nothing.
    result: Option<TypeHandle>,
    /// Argument types, in the order values are popped from the stack.
    arguments: Vec<TypeHandle>,
}

// SAFETY: an `FfiFunction` is only mutated through `&mut self`; shared references
// only read the stored code pointer and type handles.
// NOTE(review): this also assumes `TypeHandle` is itself thread-safe and that the
// pointed-to code stays loaded while the value is shared — confirm.
unsafe impl Send for FfiFunction {}
unsafe impl Sync for FfiFunction {}
94
95impl FfiFunction {
96    pub fn from_function_signature(function: FfiCodePtr, signature: &FunctionSignature) -> Self {
97        FfiFunction {
98            function,
99            result: signature
100                .outputs
101                .iter()
102                .find(|param| param.name == "result")
103                .map(|param| param.type_handle.clone()),
104            arguments: signature
105                .inputs
106                .iter()
107                .map(|param| param.type_handle.clone())
108                .collect(),
109        }
110    }
111
112    pub fn build_function(function: FfiCodePtr, signature: FunctionSignature) -> Function {
113        let ffi = Self::from_function_signature(function, &signature);
114        Function::new(
115            signature,
116            FunctionBody::Closure(Arc::new(move |context, _| unsafe {
117                ffi.call(context).expect("FFI call error");
118            })),
119        )
120    }
121
122    pub fn new(function: FfiCodePtr) -> Self {
123        Self {
124            function,
125            result: Default::default(),
126            arguments: Default::default(),
127        }
128    }
129
130    pub fn with_result(mut self, type_: TypeHandle) -> Self {
131        self.result(type_);
132        self
133    }
134
135    pub fn with_argument(mut self, type_: TypeHandle) -> Self {
136        self.argument(type_);
137        self
138    }
139
140    pub fn result(&mut self, type_: TypeHandle) {
141        self.result = Some(type_);
142    }
143
144    pub fn argument(&mut self, type_: TypeHandle) {
145        self.arguments.push(type_);
146    }
147
148    /// # Safety
149    pub unsafe fn call(&self, context: &mut Context) -> Result<(), Box<dyn Error>> {
150        let mut arguments_data = self
151            .arguments
152            .iter()
153            .map(|type_| {
154                if let Some((_, type_hash, _, data)) = context.stack().pop_raw() {
155                    if type_hash == type_.type_hash() {
156                        Ok(data)
157                    } else {
158                        Err(
159                            format!("Popped value from stack is not `{}` type!", type_.name())
160                                .into(),
161                        )
162                    }
163                } else {
164                    Err(format!("Could not pop `{}` type value from stack!", type_.name()).into())
165                }
166            })
167            .collect::<Result<Vec<_>, Box<dyn Error>>>()?;
168        let mut arguments = arguments_data
169            .iter_mut()
170            .map(|data| data.as_mut_ptr() as *mut c_void)
171            .collect::<Vec<_>>();
172        let mut types = Vec::with_capacity(self.arguments.len() + 1);
173        types.push(
174            self.result
175                .as_ref()
176                .map(Self::make_type)
177                .unwrap_or(ffi_type_void),
178        );
179        for type_ in &self.arguments {
180            types.push(Self::make_type(type_));
181        }
182        let return_type = &mut types[0] as *mut _;
183        let mut argument_types = types[1..]
184            .iter_mut()
185            .map(|type_| type_ as *mut _)
186            .collect::<Vec<_>>();
187        let mut cif = ffi_cif::default();
188        ffi_prep_cif(
189            &mut cif as *mut _,
190            ffi_abi_FFI_DEFAULT_ABI,
191            arguments_data.len() as _,
192            return_type,
193            argument_types.as_mut_ptr(),
194        );
195        let mut result = vec![0u8; return_type.as_ref().unwrap().size];
196        ffi_call(
197            &mut cif as *mut _,
198            Some(*self.function.as_safe_fun()),
199            result.as_mut_ptr() as *mut _,
200            arguments.as_mut_ptr(),
201        );
202        if let Some(type_) = self.result.as_ref() {
203            context.stack().push_raw(
204                *type_.layout(),
205                type_.type_hash(),
206                type_.finalizer(),
207                &result,
208            );
209        }
210        Ok(())
211    }
212
213    fn make_type(type_: &TypeHandle) -> ffi_type {
214        let layout = type_.layout();
215        ffi_type {
216            size: layout.size(),
217            alignment: layout.align() as _,
218            type_: FFI_TYPE_STRUCT as _,
219            elements: null_mut(),
220        }
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use intuicio_core::prelude::*;

    /// Native test target: returns `a + b`.
    extern "C" fn add(a: i32, b: i32) -> i32 {
        a + b
    }

    /// Native test target: asserts that `v` equals 42.
    extern "C" fn ensure_42(v: i32) {
        assert_eq!(v, 42);
    }

    /// Compile-time check that `T` is both `Send` and `Sync`.
    fn is_async<T: Send + Sync>() {}

    /// Pushes the two operands consumed by `add` onto the stack: 2 first, then 40.
    fn push_operands(context: &mut Context) {
        context.stack().push(2i32);
        context.stack().push(40i32);
    }

    #[test]
    fn test_ffi_function() {
        is_async::<FfiFunction>();

        let registry = Registry::default().with_basic_types();
        let mut context = Context::new(10240, 10240);
        let i32_type = registry.find_type(TypeQuery::of::<i32>()).unwrap();

        // Described by hand from type handles.
        let manual_add = FfiFunction::new(FfiCodePtr(add as *mut _))
            .with_argument(i32_type.clone())
            .with_argument(i32_type.clone())
            .with_result(i32_type.clone());
        let manual_ensure =
            FfiFunction::new(FfiCodePtr(ensure_42 as *mut _)).with_argument(i32_type.clone());
        push_operands(&mut context);
        unsafe {
            manual_add.call(&mut context).unwrap();
            manual_ensure.call(&mut context).unwrap();
        }

        // Described by intuicio function signatures.
        let signature_add = FfiFunction::from_function_signature(
            FfiCodePtr(add as *mut _),
            &function_signature!(&registry => fn add(a: i32, b: i32) -> (result: i32)),
        );
        let signature_ensure = FfiFunction::from_function_signature(
            FfiCodePtr(ensure_42 as *mut _),
            &function_signature!(&registry => fn ensure_42(v: i32) -> ()),
        );
        push_operands(&mut context);
        unsafe {
            signature_add.call(&mut context).unwrap();
            signature_ensure.call(&mut context).unwrap();
        }

        // Wrapped into intuicio functions and run through `Function::invoke`.
        let wrapped_add = FfiFunction::build_function(
            FfiCodePtr(add as *mut _),
            function_signature!(&registry => fn add(a: i32, b: i32) -> (result: i32)),
        );
        let wrapped_ensure = FfiFunction::build_function(
            FfiCodePtr(ensure_42 as *mut _),
            function_signature!(&registry => fn ensure_42(v: i32) -> ()),
        );
        push_operands(&mut context);
        wrapped_add.invoke(&mut context, &registry);
        wrapped_ensure.invoke(&mut context, &registry);
    }

    #[test]
    fn test_ffi_library() {
        is_async::<FfiLibrary>();

        let registry = Registry::default().with_basic_types();
        let mut context = Context::new(10240, 10240);
        // Expects a library named `ffi` in the workspace's debug build output; the
        // platform extension is appended by `FfiLibrary::new`.
        let mut library = FfiLibrary::new("../../target/debug/ffi").unwrap();
        let add_handle = library
            .function(function_signature!(&registry => fn add(a: i32, b: i32) -> (result: i32)))
            .unwrap();
        let ensure_handle = library
            .function(function_signature!(&registry => fn ensure_42(v: i32) -> ()))
            .unwrap();
        push_operands(&mut context);
        unsafe {
            add_handle.call(&mut context).unwrap();
            ensure_handle.call(&mut context).unwrap();
        }
    }
}