// postcard_rpc/host_client/raw_nusb.rs
//! Implementation of transport using nusb

use std::future::Future;

use nusb::{
    transfer::{Queue, RequestBuffer, TransferError},
    DeviceInfo, InterfaceInfo,
};
use postcard_schema::Schema;
use serde::de::DeserializeOwned;

use crate::{
    header::VarSeqKind,
    host_client::{HostClient, WireRx, WireSpawn, WireTx},
};

// TODO: These should all be configurable, PRs welcome
//
// Until then, every nusb-backed `HostClient` created in this module uses the
// fixed endpoint addresses and queue tuning values below.

/// The Bulk Out Endpoint (0x00 | 0x01): Out EP 1
///
/// Used both to open the outgoing bulk queue and to look up the endpoint's
/// max packet size in the active configuration.
pub(crate) const BULK_OUT_EP: u8 = 0x01;
/// The Bulk In Endpoint (0x80 | 0x01): In EP 1
///
/// Used to open the incoming bulk queue.
pub(crate) const BULK_IN_EP: u8 = 0x81;
/// The size in bytes of the largest possible IN transfer
///
/// Every receive buffer submitted to the IN queue is requested with this size.
pub(crate) const MAX_TRANSFER_SIZE: usize = 1024;
/// How many in-flight requests at once - allows nusb to keep pulling frames
/// even if we haven't processed them host-side yet.
///
/// The receive loop tops the IN queue back up to this many pending requests
/// before waiting for the next completion.
pub(crate) const IN_FLIGHT_REQS: usize = 4;
/// How many consecutive IN errors will we try to recover from before giving up?
///
/// Only `Stall` and `Unknown` transfer errors are retried; all other errors
/// are treated as fatal immediately.
pub(crate) const MAX_STALL_RETRIES: usize = 10;

/// # `nusb` Constructor Methods
///
/// These methods are used to create a new [HostClient] instance for use with `nusb` and
/// USB bulk transfer encoding.
///
/// **Requires feature**: `raw-nusb`
impl<WireErr> HostClient<WireErr>
where
    WireErr: DeserializeOwned + Schema,
{
    /// Try to create a new link using [`nusb`] for connectivity
    ///
    /// The provided function will be used to find a matching device. The first
    /// matching device will be connected to. `err_uri_path` is
    /// the path associated with the `WireErr` message type.
    ///
    /// On non-Windows platforms, the first interface with the "Vendor Specific"
    /// class (0xFF) is claimed, using the interface number reported by that
    /// interface.
    ///
    /// Returns an error if no device could be found, or if there was an error
    /// connecting to the device.
    ///
    /// This constructor is available when the `raw-nusb` feature is enabled.
    ///
    /// ## Platform specific support
    ///
    /// When using Windows, the WinUSB driver does not allow enumerating interfaces.
    /// When on windows, this method will ALWAYS try to connect to interface zero.
    /// This limitation may be removed in the future, and if so, will be changed to
    /// look for the first interface with the class of 0xFF.
    ///
    /// ## Example
    ///
    /// ```rust,no_run
    /// use postcard_rpc::host_client::HostClient;
    /// use postcard_rpc::header::VarSeqKind;
    /// use serde::{Serialize, Deserialize};
    /// use postcard_schema::Schema;
    ///
    /// /// A "wire error" type your server can use to respond to any
    /// /// kind of request, for example if deserializing a request fails
    /// #[derive(Debug, PartialEq, Schema, Serialize, Deserialize)]
    /// pub enum Error {
    ///    SomethingBad
    /// }
    ///
    /// let client = HostClient::<Error>::try_new_raw_nusb(
    ///     // Find the first device with the serial 12345678
    ///     |d| d.serial_number() == Some("12345678"),
    ///     // the URI/path for `Error` messages
    ///     "error",
    ///     // Outgoing queue depth in messages
    ///     8,
    ///     // Use one-byte sequence numbers
    ///     VarSeqKind::Seq1,
    /// ).unwrap();
    /// ```
    pub fn try_new_raw_nusb<F: FnMut(&DeviceInfo) -> bool>(
        func: F,
        err_uri_path: &str,
        outgoing_depth: usize,
        seq_no_kind: VarSeqKind,
    ) -> Result<Self, String> {
        let device_info = Self::find_raw_nusb_device(func)?;

        // NOTE: We can't enumerate interfaces on Windows. For now, just use
        // a hardcoded interface of zero instead of trying to find the right one
        //
        // Claiming needs the interface *number*, which is not necessarily the
        // same as the interface's position in the enumeration order.
        #[cfg(not(target_os = "windows"))]
        let interface_number = device_info
            .interfaces()
            .find(|i| i.class() == 0xFF)
            .map(|i| i.interface_number())
            .ok_or_else(|| String::from("Failed to find matching interface!!"))?;

        #[cfg(target_os = "windows")]
        let interface_number: u8 = 0;

        Self::connect_raw_nusb(
            &device_info,
            interface_number,
            err_uri_path,
            outgoing_depth,
            seq_no_kind,
        )
    }

    /// Try to create a new link using [`nusb`] for connectivity
    ///
    /// The provided functions will be used to find a matching device and a matching
    /// interface on that device. The first matching device will be connected to, and
    /// the first matching interface will be claimed using the interface number it
    /// reports. `err_uri_path` is the path associated with the `WireErr` message type.
    ///
    /// Returns an error if no device or interface could be found, or if there was an error
    /// connecting to the device or interface.
    ///
    /// This constructor is available when the `raw-nusb` feature is enabled.
    ///
    /// ## Platform specific support
    ///
    /// When using Windows, the WinUSB driver does not allow enumerating interfaces.
    /// Therefore, this constructor is not available on windows. This limitation may
    /// be removed in the future.
    ///
    /// ## Example
    ///
    /// ```rust,no_run
    /// use postcard_rpc::host_client::HostClient;
    /// use postcard_rpc::header::VarSeqKind;
    /// use serde::{Serialize, Deserialize};
    /// use postcard_schema::Schema;
    ///
    /// /// A "wire error" type your server can use to respond to any
    /// /// kind of request, for example if deserializing a request fails
    /// #[derive(Debug, PartialEq, Schema, Serialize, Deserialize)]
    /// pub enum Error {
    ///    SomethingBad
    /// }
    ///
    /// let client = HostClient::<Error>::try_new_raw_nusb_with_interface(
    ///     // Find the first device with the serial 12345678
    ///     |d| d.serial_number() == Some("12345678"),
    ///     // Find the "Vendor Specific" interface
    ///     |i| i.class() == 0xFF,
    ///     // the URI/path for `Error` messages
    ///     "error",
    ///     // Outgoing queue depth in messages
    ///     8,
    ///     // Use one-byte sequence numbers
    ///     VarSeqKind::Seq1,
    /// ).unwrap();
    /// ```
    #[cfg(not(target_os = "windows"))]
    pub fn try_new_raw_nusb_with_interface<
        F1: FnMut(&DeviceInfo) -> bool,
        F2: FnMut(&InterfaceInfo) -> bool,
    >(
        device_func: F1,
        mut interface_func: F2,
        err_uri_path: &str,
        outgoing_depth: usize,
        seq_no_kind: VarSeqKind,
    ) -> Result<Self, String> {
        let device_info = Self::find_raw_nusb_device(device_func)?;

        // Claiming needs the interface *number*, which is not necessarily the
        // same as the interface's position in the enumeration order.
        let interface_number = device_info
            .interfaces()
            .find(|i| interface_func(*i))
            .map(|i| i.interface_number())
            .ok_or_else(|| String::from("Failed to find matching interface!!"))?;

        Self::connect_raw_nusb(
            &device_info,
            interface_number,
            err_uri_path,
            outgoing_depth,
            seq_no_kind,
        )
    }

    /// Create a new link using [`nusb`] for connectivity
    ///
    /// See [`Self::try_new_raw_nusb()`] for more details.
    ///
    /// This constructor is available when the `raw-nusb` feature is enabled.
    ///
    /// ## Panics
    ///
    /// Panics if connection fails, i.e. whenever [`Self::try_new_raw_nusb()`]
    /// would have returned an error.
    ///
    /// ## Example
    ///
    /// ```rust,no_run
    /// use postcard_rpc::host_client::HostClient;
    /// use postcard_rpc::header::VarSeqKind;
    /// use serde::{Serialize, Deserialize};
    /// use postcard_schema::Schema;
    ///
    /// /// A "wire error" type your server can use to respond to any
    /// /// kind of request, for example if deserializing a request fails
    /// #[derive(Debug, PartialEq, Schema, Serialize, Deserialize)]
    /// pub enum Error {
    ///    SomethingBad
    /// }
    ///
    /// let client = HostClient::<Error>::new_raw_nusb(
    ///     // Find the first device with the serial 12345678
    ///     |d| d.serial_number() == Some("12345678"),
    ///     // the URI/path for `Error` messages
    ///     "error",
    ///     // Outgoing queue depth in messages
    ///     8,
    ///     // Use one-byte sequence numbers
    ///     VarSeqKind::Seq1,
    /// );
    /// ```
    pub fn new_raw_nusb<F: FnMut(&DeviceInfo) -> bool>(
        func: F,
        err_uri_path: &str,
        outgoing_depth: usize,
        seq_no_kind: VarSeqKind,
    ) -> Self {
        Self::try_new_raw_nusb(func, err_uri_path, outgoing_depth, seq_no_kind)
            .expect("should have found nusb device")
    }

    /// Return the first device reported by `nusb` for which `func` returns `true`.
    ///
    /// Fails if the device list cannot be obtained or no device matches.
    fn find_raw_nusb_device<F: FnMut(&DeviceInfo) -> bool>(
        func: F,
    ) -> Result<DeviceInfo, String> {
        nusb::list_devices()
            .map_err(|e| format!("Error listing devices: {e:?}"))?
            .find(func)
            .ok_or_else(|| String::from("Failed to find matching nusb device!"))
    }

    /// Open `device_info`, claim the interface numbered `interface_number`, and
    /// build a client on top of the `BULK_OUT_EP` / `BULK_IN_EP` bulk queues.
    ///
    /// Shared tail of the public `try_new_raw_nusb*` constructors.
    fn connect_raw_nusb(
        device_info: &DeviceInfo,
        interface_number: u8,
        err_uri_path: &str,
        outgoing_depth: usize,
        seq_no_kind: VarSeqKind,
    ) -> Result<Self, String> {
        let dev = device_info
            .open()
            .map_err(|e| format!("Failed opening device: {e:?}"))?;
        let interface = dev
            .claim_interface(interface_number)
            .map_err(|e| format!("Failed claiming interface: {e:?}"))?;

        // Look for the OUT endpoint in every alt setting of the active
        // configuration and keep the smallest max packet size seen. The sender
        // uses it to decide when a zero-length packet must follow a transfer.
        let mut mps: Option<usize> = None;
        if let Ok(config) = dev.active_configuration() {
            for ias in config.interface_alt_settings() {
                for ep in ias.endpoints() {
                    if ep.address() == BULK_OUT_EP {
                        let ep_mps = ep.max_packet_size();
                        mps = Some(mps.map_or(ep_mps, |old| old.min(ep_mps)));
                    }
                }
            }
        }

        if let Some(max_packet_size) = &mps {
            tracing::debug!(max_packet_size, "Detected max packet size");
        } else {
            tracing::warn!("Unable to detect Max Packet Size!");
        };

        let boq = interface.bulk_out_queue(BULK_OUT_EP);
        let biq = interface.bulk_in_queue(BULK_IN_EP);

        Ok(HostClient::new_with_wire(
            NusbWireTx {
                boq,
                max_packet_size: mps,
            },
            NusbWireRx {
                biq,
                consecutive_errs: 0,
            },
            NusbSpawn,
            seq_no_kind,
            err_uri_path,
            outgoing_depth,
        ))
    }
}

//////////////////////////////////////////////////////////////////////////////
// Wire Interface Implementation
//////////////////////////////////////////////////////////////////////////////

/// Task spawner used by the `nusb`-backed host client.
///
/// Futures are handed to Tokio via `tokio::task::spawn`.
struct NusbSpawn;

impl WireSpawn for NusbSpawn {
    /// Spawn `fut` as a detached Tokio task.
    fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
        // The returned JoinHandle is itself a Future. Bind it and drop it on
        // purpose so the detach is explicit and clippy does not complain about
        // a future that is silently discarded.
        let handle = tokio::task::spawn(fut);
        drop(handle);
    }
}

/// NUSB Wire Transmit Interface Implementor
///
/// Sends outgoing frames to the device over a bulk OUT queue.
struct NusbWireTx {
    /// Bulk OUT transfer queue that outgoing frames are submitted to.
    boq: Queue<Vec<u8>>,
    /// Max packet size detected for the bulk OUT endpoint, or `None` if it
    /// could not be detected. Used to decide whether a zero-length packet has
    /// to follow a frame; when `None`, one is always sent.
    max_packet_size: Option<usize>,
}

/// Errors that can occur while sending a frame over the `nusb` bulk OUT queue.
#[derive(thiserror::Error, Debug)]
enum NusbWireTxError {
    /// An OUT transfer completed with an error status.
    #[error("Transfer Error on Send")]
    Transfer(#[from] TransferError),
}

impl WireTx for NusbWireTx {
    type Error = NusbWireTxError;

    #[inline]
    fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>> + Send {
        self.send_inner(data)
    }
}

impl NusbWireTx {
    /// Submit `data` as one bulk OUT transfer and wait for it to complete.
    ///
    /// If the frame length is an exact multiple of the endpoint's max packet
    /// size, a zero-length packet (ZLP) is queued right behind it to mark the
    /// end of the transfer. When the max packet size is unknown (or was
    /// reported as zero), a ZLP is always sent.
    ///
    /// Returns the first transfer error encountered. On error, any transfer
    /// this call still has queued is cancelled and drained before returning,
    /// so a later call never picks up a stale completion from this one.
    async fn send_inner(&mut self, data: Vec<u8>) -> Result<(), NusbWireTxError> {
        let needs_zlp = match self.max_packet_size {
            // Guard against a zero max packet size, which would otherwise
            // panic with a division by zero.
            Some(mps) if mps > 0 => data.len() % mps == 0,
            _ => true,
        };

        self.boq.submit(data);

        // Append ZLP if we are a multiple of max packet
        if needs_zlp {
            self.boq.submit(Vec::new());
        }

        // One completion per transfer submitted above, in submission order.
        let submitted = if needs_zlp { 2 } else { 1 };
        for _ in 0..submitted {
            let completion = self.boq.next_complete().await;
            if let Err(e) = completion.status {
                tracing::error!("Output Queue Error: {e:?}");

                // The frame is already lost. Cancel whatever is still queued
                // (at most the ZLP) and collect its completion, leaving the
                // queue empty for the next send.
                self.boq.cancel_all();
                while self.boq.pending() > 0 {
                    let _ = self.boq.next_complete().await;
                }

                return Err(e.into());
            }
        }

        Ok(())
    }
}

/// NUSB Wire Receive Interface Implementor
///
/// Receives incoming frames from the device over a bulk IN queue.
struct NusbWireRx {
    /// Bulk IN transfer queue that receive buffers are submitted to.
    biq: Queue<RequestBuffer>,
    /// Number of IN transfers that have failed in a row. Incremented on every
    /// failed completion and reset to zero on the next successful one.
    consecutive_errs: usize,
}

/// Errors that can occur while receiving a frame over the `nusb` bulk IN queue.
#[derive(thiserror::Error, Debug)]
enum NusbWireRxError {
    /// An IN transfer completed with an error status that was either not
    /// recoverable or could not be recovered from.
    #[error("Transfer Error on Recv")]
    Transfer(#[from] TransferError),
}

impl WireRx for NusbWireRx {
    type Error = NusbWireRxError;

    #[inline]
    fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>> + Send {
        self.recv_inner()
    }
}

impl NusbWireRx {
    async fn recv_inner(&mut self) -> Result<Vec<u8>, NusbWireRxError> {
        loop {
            // Rehydrate the queue
            let pending = self.biq.pending();
            for _ in 0..(IN_FLIGHT_REQS.saturating_sub(pending)) {
                self.biq.submit(RequestBuffer::new(MAX_TRANSFER_SIZE));
            }

            let res = self.biq.next_complete().await;

            if let Err(e) = res.status {
                self.consecutive_errs += 1;

                tracing::error!(
                    "In Worker error: {e:?}, consecutive: {}",
                    self.consecutive_errs
                );

                // Docs only recommend this for Stall, but it seems to work with
                // UNKNOWN on MacOS as well, todo: look into why!
                //
                // Update: This stall condition seems to have been due to an errata in the
                // STM32F4 USB hardware. See https://github.com/embassy-rs/embassy/pull/2823
                //
                // It is now questionable whether we should be doing this stall recovery at all,
                // as it likely indicates an issue with the connected USB device
                let recoverable = match e {
                    TransferError::Stall | TransferError::Unknown => {
                        self.consecutive_errs <= MAX_STALL_RETRIES
                    }
                    TransferError::Cancelled => false,
                    TransferError::Disconnected => false,
                    TransferError::Fault => false,
                };

                let fatal = if recoverable {
                    tracing::warn!("Attempting stall recovery!");

                    // Stall recovery shouldn't be used with in-flight requests, so
                    // cancel them all. They'll still pop out of next_complete.
                    self.biq.cancel_all();
                    tracing::info!("Cancelled all in-flight requests");

                    // Now we need to join all in flight requests
                    for _ in 0..(IN_FLIGHT_REQS - 1) {
                        let res = self.biq.next_complete().await;
                        tracing::info!("Drain state: {:?}", res.status);
                    }

                    // Now we can mark the stall as clear
                    match self.biq.clear_halt() {
                        Ok(()) => false,
                        Err(e) => {
                            tracing::error!("Failed to clear stall: {e:?}, Fatal.");
                            true
                        }
                    }
                } else {
                    tracing::error!(
                        "Giving up after {} errors in a row, final error: {e:?}",
                        self.consecutive_errs
                    );
                    true
                };

                if fatal {
                    tracing::error!("Fatal Error, exiting");
                    // When we close the channel, all pending receivers and subscribers
                    // will be notified
                    return Err(e.into());
                } else {
                    tracing::info!("Potential recovery, resuming NusbWireRx::recv_inner");
                    continue;
                }
            }

            // If we get a good decode, clear the error flag
            if self.consecutive_errs != 0 {
                tracing::info!("Clearing consecutive error counter after good header decode");
                self.consecutive_errs = 0;
            }

            return Ok(res.data);
        }
    }
}