ray_rs_sys/
lib.rs

1#![allow(non_upper_case_globals)]
2#![allow(non_camel_case_types)]
3#![allow(non_snake_case)]
4#![allow(deref_nullptr)]
5
6#[cfg(not(feature = "bazel"))]
7include!(concat!(env!("OUT_DIR"), "/ray_rs_sys_bindgen.rs"));
8
9#[cfg(feature = "bazel")]
10include!(env!("BAZEL_BINDGEN_SOURCE"));
11
12use std::os::raw::*;
13use std::ffi::CString;
14
15// pub type execute_function
16
/// Launch parameters for a core worker: driver flag, code search path and head arguments.
///
/// NOTE(review): not constructed anywhere in this file; `ray::init_inner` builds the same
/// values as locals instead. Either wire it in or remove it.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
struct LaunchConfig {
    /// Whether this process acts as the driver.
    is_driver: bool,
    /// Code search path, kept as a C string for direct hand-off over FFI.
    code_search_path: CString,
    /// Head arguments, kept as a C string for direct hand-off over FFI.
    head_args: CString,
}
22
23pub type MaybeExecuteCallback = c_worker_ExecuteCallback;
24
25pub extern "C" fn rust_worker_execute_dummy(
26    _task_type: RayInt,
27    _ray_function_info: RaySlice,
28    _args: RaySlice,
29    _return_values: RaySlice,
30) {
31}
32
33pub mod ray {
34    use super::*;
35    pub fn init_inner(
36        is_driver: bool,
37        f: MaybeExecuteCallback,
38        // d: MaybeBufferDestructor,
39        argc_v: Option<(c_int, *const *const c_char)>
40    ) {
41        unsafe {
42            let mut code_search_path = CString::new("").unwrap();
43            let mut head_args = CString::new("").unwrap();
44
45            c_worker_RegisterExecutionCallback(f);
46            // c_worker_RegisterBufferDestructor(d);
47
48            let (argc, argv) = argc_v.unwrap_or((0, std::ptr::null()));
49
50            c_worker_InitConfig(
51                if is_driver { 1 } else { 0 }, 3, 1,
52                code_search_path.as_ptr() as *mut c_char,
53                head_args.as_ptr() as *mut c_char,
54                argc, argv as *mut *mut c_char,
55            );
56            c_worker_Initialize();
57        }
58    }
59
60    pub fn shutdown() {
61        unsafe {
62            c_worker_Shutdown();
63        }
64    }
65
66    pub fn run() {
67        unsafe {
68            c_worker_Run();
69        }
70    }
71}
72
73pub mod util {
74    use super::dv_as_slice;
75    use std::ffi::CString;
76    pub fn add_local_ref(id: CString) {
77        unsafe {
78            super::c_worker_AddLocalRef(id.into_raw())
79        }
80    }
81
82    pub fn remove_local_ref(id: CString) {
83        unsafe {
84            super::c_worker_RemoveLocalRef(id.into_raw())
85        }
86    }
87
88    pub fn pretty_print_id(id: &CString) -> String {
89        id.as_bytes()
90            .iter()
91            .map(|x| format!("{:02x?}", x))
92            .collect::<Vec<_>>()
93            .join("")
94    }
95
96    pub fn log_internal(msg: String) {
97        unsafe {
98            super::c_worker_Log(std::ffi::CString::new(msg).unwrap().into_raw());
99        }
100    }
101//     pub fn fd_to_cstring(fd: RaySlice) -> CString {
102//         CString::from(fd.data as *c_char)
103//     }
104}
105
106
107pub mod internal {
108    use super::*;
109    // One can use Vec<&'a[u8]> in the function signature instead since SubmitTask is synchronous?
110    pub fn submit(fn_name: CString, args: &mut Vec<Vec<u8>>) -> CString {
111        unsafe {
112            // Create data
113            let mut meta_vec = vec![0u8];
114            let mut data = args
115                .iter_mut()
116                .map(|data_vec| {
117                    c_worker_AllocateDataValue(
118                        // Why is this a void pointer, not a void/char ptr?
119                        (*data_vec).as_mut_ptr(),
120                        data_vec.len() as u64,
121                        std::ptr::null_mut(),
122                        0u64,
123                    )
124                })
125                .collect::<Vec<*mut DataValue>>();
126
127            let mut obj_ids = vec![std::ptr::null_mut()];
128            let mut is_refs = vec![false; args.len()];
129
130            c_worker_SubmitTask(
131                fn_name.into_raw(),
132                is_refs.as_mut_ptr(),
133                data.as_mut_ptr(),
134                std::ptr::null_mut::<*mut c_char>(),
135                data.len() as i32,
136                1,
137                obj_ids.as_mut_ptr()
138            );
139
140            let c_str_id = CString::from_raw(obj_ids[0]);
141            println!("ObjectID: {:x?}", util::pretty_print_id(&c_str_id));
142            c_str_id
143        }
144    }
145
146    pub fn get_slice<'a>(id: CString, timeout: i32) -> &'a mut [u8] {
147        dv_as_slice(get(id, timeout))
148    }
149
150    #[inline]
151    fn get(id: CString, timeout: i32) -> DataValue {
152        let mut data = vec![id.as_ptr()];
153        let mut d_value: Vec<*mut DataValue> = vec![std::ptr::null_mut() as *mut _];
154        unsafe {
155            c_worker_Get(
156                data.as_ptr() as *mut *mut c_char,
157                1,
158                timeout,
159                d_value.as_ptr() as *mut *mut DataValue
160            );
161            *d_value[0] as DataValue
162        }
163    }
164}
165
166pub fn dv_as_slice<'a>(data: DataValue) -> &'a mut [u8] {
167    unsafe {
168        std::slice::from_raw_parts_mut::<u8>(
169            (*data.data).p,
170            (*data.data).size as usize,
171        )
172    }
173}
174
#[cfg(test)]
pub mod test {
    use super::*;

    /// `c_worker_AllocateDataValue` should point at the caller's buffers (no copy) and record
    /// each buffer's own size.
    #[test]
    fn test_allocate_data() {
        let mut data_vec = vec![1u8, 2];
        // Different length from `data_vec` so a data/meta size mix-up is caught.
        let mut meta_vec = vec![3u8, 4, 5];
        unsafe {
            let data = c_worker_AllocateDataValue(
                data_vec.as_mut_ptr().cast(),
                data_vec.len() as u64,
                meta_vec.as_mut_ptr().cast(),
                meta_vec.len() as u64,
            );
            assert!(!data.is_null(), "c_worker_AllocateDataValue returned null");
            assert_eq!((*(*data).data).p, data_vec.as_mut_ptr());
            assert_eq!((*(*data).meta).p, meta_vec.as_mut_ptr());
            assert_eq!((*(*data).data).size, data_vec.len() as u64);
            assert_eq!((*(*data).meta).size, meta_vec.len() as u64);
        }
    }

    /// Registering the no-op callback should report status `1`.
    #[test]
    fn test_register_callback() {
        unsafe {
            assert_eq!(
                c_worker_RegisterExecutionCallback(Some(rust_worker_execute_dummy)),
                1,
                "Failed to register execute callback"
            );
        }
    }

    /// Round-trips a buffer through `c_worker_Put` / `c_worker_Get` on a freshly initialized
    /// driver, both via raw FFI and via `internal::get_slice`.
    #[test]
    fn test_put_get_raw() {
        ray::init_inner(true, Some(rust_worker_execute_dummy), None);
        unsafe {
            let mut data_vec = vec![1u8, 2];
            let mut meta_vec = vec![3u8, 4];
            let mut data = vec![c_worker_AllocateDataValue(
                data_vec.as_mut_ptr().cast(),
                data_vec.len() as u64,
                meta_vec.as_mut_ptr().cast(),
                meta_vec.len() as u64,
            )];

            let mut obj_ids: Vec<*mut c_char> = vec![std::ptr::null_mut()];
            c_worker_Put(obj_ids.as_mut_ptr(), -1, data.as_mut_ptr(), data.len() as i32);
            assert!(!obj_ids[0].is_null(), "c_worker_Put did not return an object ID");

            let c_str_id = CString::from_raw(obj_ids[0]);
            println!("{:x?}", c_str_id);

            let mut get_data: Vec<*mut DataValue> = vec![std::ptr::null_mut()];
            c_worker_Get(obj_ids.as_mut_ptr(), 1, -1, get_data.as_mut_ptr());
            assert!(!get_data[0].is_null(), "c_worker_Get returned a null DataValue");

            let slice = std::slice::from_raw_parts_mut::<u8>(
                (*(*get_data[0]).data).p as *mut u8,
                (*(*get_data[0]).data).size as usize,
            );
            assert_eq!(slice, &data_vec);

            // `internal::get` is private to its module; use the public `get_slice` wrapper.
            assert_eq!(internal::get_slice(c_str_id, -1), &data_vec);

            c_worker_Shutdown();
        }
    }
}
274
/// C-ABI signature for releasing a `(ptr, len)` byte buffer; `rust_raw_parts_dealloc` has
/// this shape.
///
/// NOTE(review): not referenced elsewhere in this file.
type BufferDestructor = extern "C" fn(ptr: *mut u8, len: u64);
276
/// Frees a byte buffer that was handed to C code as a raw `(ptr, len)` pair.
///
/// Matches the `extern "C" fn(*mut u8, u64)` shape so it can be passed across the FFI
/// boundary as a destructor.
///
/// Contract (not enforced by the type system, since this is a safe `extern "C"` fn): `ptr`
/// must come from `Box::<[u8]>::into_raw` (e.g. `vec.into_boxed_slice()`) with exactly `len`
/// elements, and must not be used or freed again afterwards. A null `ptr` is ignored.
///
/// Note that `ptr::drop_in_place` on a `[u8]` only runs element destructors (none for `u8`)
/// and never releases the allocation; the owning `Box` must be rebuilt and dropped instead.
pub extern "C" fn rust_raw_parts_dealloc(ptr: *mut u8, len: u64) {
    use std::convert::TryFrom;

    if ptr.is_null() {
        return;
    }
    let len = match usize::try_from(len) {
        Ok(len) => len,
        // A length that does not fit in `usize` cannot describe a live allocation; leak rather
        // than panic across the FFI boundary.
        Err(_) => return,
    };
    // SAFETY: per the contract above, `ptr`/`len` describe a live `Box<[u8]>` now owned here.
    unsafe {
        drop(Box::from_raw(std::ptr::slice_from_raw_parts_mut(ptr, len)));
    }
}