sherpa_rs/
punctuate.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
use eyre::{bail, Result};

use crate::{
    get_default_provider,
    utils::{cstr_to_string, RawCStr},
};

/// Options used by [`Punctuation::new`] to build the native punctuation engine.
#[derive(Debug, Default, Clone)]
pub struct PunctuationConfig {
    /// Model identifier forwarded to the native config as its `ct_transformer` field.
    /// NOTE(review): presumably a filesystem path to the model file — confirm.
    pub model: String,
    /// Debug flag forwarded as-is to the native model config.
    pub debug: bool,
    /// Thread count forwarded to the native model config; `None` is treated as `1`.
    pub num_threads: Option<i32>,
    /// Execution provider name. When `None`, `"cpu"` is used on macOS and the result of
    /// `get_default_provider()` on every other target.
    pub provider: Option<String>,
}

/// Handle to a native offline punctuation engine.
///
/// The pointer is obtained from `SherpaOnnxCreateOfflinePunctuation` in
/// [`Punctuation::new`] and passed to `SherpaOnnxDestroyOfflinePunctuation` when the
/// value is dropped.
pub struct Punctuation {
    // Raw handle returned by the native library; `new` rejects a null value.
    audio_punctuation: *const sherpa_rs_sys::SherpaOnnxOfflinePunctuation,
}

impl Punctuation {
    /// Creates a punctuation engine from `config`.
    ///
    /// Unset options fall back to defaults: `num_threads` becomes `1`, and `provider`
    /// becomes `"cpu"` on macOS or the value of `get_default_provider()` elsewhere.
    ///
    /// # Errors
    ///
    /// Returns an error when the native library hands back a null handle.
    pub fn new(config: PunctuationConfig) -> Result<Self> {
        // `model` and `provider` stay bound until the end of the function, so they are
        // still in scope when the pointers taken from them are read by the create call.
        let model = RawCStr::new(&config.model);
        // Resolve the fallback lazily: the default provider is only computed when the
        // caller did not supply one.
        let provider = RawCStr::new(&config.provider.unwrap_or_else(|| {
            if cfg!(target_os = "macos") {
                // TODO: sherpa-onnx/issues/1448
                "cpu".into()
            } else {
                get_default_provider()
            }
        }));

        let sherpa_config = sherpa_rs_sys::SherpaOnnxOfflinePunctuationConfig {
            model: sherpa_rs_sys::SherpaOnnxOfflinePunctuationModelConfig {
                ct_transformer: model.as_ptr(),
                num_threads: config.num_threads.unwrap_or(1),
                debug: config.debug.into(),
                provider: provider.as_ptr(),
            },
        };
        // SAFETY: `sherpa_config` is a fully initialised local, and the pointers it holds
        // come from `model` and `provider`, both of which are still in scope here.
        let audio_punctuation =
            unsafe { sherpa_rs_sys::SherpaOnnxCreateOfflinePunctuation(&sherpa_config) };

        if audio_punctuation.is_null() {
            bail!("Failed to create audio punctuation")
        }
        Ok(Self { audio_punctuation })
    }

    /// Passes `text` to the native `SherpaOfflinePunctuationAddPunct` call and returns
    /// its output as an owned `String`.
    ///
    /// The C string produced by the native call is converted with `cstr_to_string` and
    /// then released with `SherpaOfflinePunctuationFreeText`.
    pub fn add_punctuation(&mut self, text: &str) -> String {
        let text = RawCStr::new(text);
        // SAFETY: `self.audio_punctuation` was checked to be non-null in `new`, and `text`
        // outlives the call that reads its pointer.
        // NOTE(review): the returned pointer is not null-checked here; confirm that
        // `cstr_to_string` and the native free function tolerate a null pointer.
        unsafe {
            let text_with_punct_ptr = sherpa_rs_sys::SherpaOfflinePunctuationAddPunct(
                self.audio_punctuation,
                text.as_ptr(),
            );
            // Copy into a Rust `String` before freeing the native buffer.
            let text_with_punct = cstr_to_string(text_with_punct_ptr);
            sherpa_rs_sys::SherpaOfflinePunctuationFreeText(text_with_punct_ptr);
            text_with_punct
        }
    }
}

// NOTE(review): these impls assert that the raw handle inside `Punctuation` may be moved to
// and shared between threads. Nothing in this file demonstrates that the native object is
// thread-safe — confirm against the sherpa-onnx C API before relying on it.
unsafe impl Send for Punctuation {}
unsafe impl Sync for Punctuation {}

impl Drop for Punctuation {
    /// Releases the native punctuation engine held by this value.
    fn drop(&mut self) {
        let handle = self.audio_punctuation;
        // SAFETY: as far as this file shows, the handle is only ever set in `new`, which
        // rejects null, and this is the only place it is destroyed.
        unsafe { sherpa_rs_sys::SherpaOnnxDestroyOfflinePunctuation(handle) };
    }
}