sherpa_rs/
audio_tag.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
use eyre::{bail, Result};

use crate::{
    get_default_provider,
    utils::{cstr_to_string, RawCStr},
};

/// Options used by [`AudioTag::new`] to build a sherpa-onnx audio tagger.
#[derive(Debug, Default, Clone)]
pub struct AudioTagConfig {
    /// Zipformer audio-tagging model, forwarded as `model.zipformer.model`
    /// (presumably a path to the model file; confirm against sherpa-onnx docs).
    pub model: String,
    /// Labels value forwarded as `labels` (presumably a path to the labels
    /// file; confirm against sherpa-onnx docs).
    pub labels: String,
    /// Number of top events requested; passed both at creation time and to
    /// every `compute` call.
    pub top_k: i32,
    /// Optional CED model, forwarded as `model.ced`. An empty string is sent
    /// when this is `None`.
    pub ced: Option<String>,
    /// Forwarded as the model `debug` flag.
    pub debug: bool,
    /// Number of threads for the model; `1` is used when this is `None`.
    pub num_threads: Option<i32>,
    /// Execution provider name; `get_default_provider()` is used when this is
    /// `None`.
    pub provider: Option<String>,
}

/// Owner of a native sherpa-onnx audio tagger created by
/// `SherpaOnnxCreateAudioTagging`, together with the configuration it was
/// built from.
pub struct AudioTag {
    /// Native tagger handle; released when the `AudioTag` is dropped.
    audio_tag: *const sherpa_rs_sys::SherpaOnnxAudioTagging,
    /// Configuration kept so `compute` can reuse `top_k`.
    config: AudioTagConfig,
}

impl Drop for AudioTag {
    /// Releases the native tagger so it does not leak when the wrapper goes
    /// out of scope.
    fn drop(&mut self) {
        if self.audio_tag.is_null() {
            return;
        }
        // SAFETY: `audio_tag` was returned by `SherpaOnnxCreateAudioTagging`,
        // is non-null, and is destroyed exactly once because `drop` runs once.
        unsafe {
            sherpa_rs_sys::SherpaOnnxDestroyAudioTagging(self.audio_tag);
        }
    }
}

impl AudioTag {
    pub fn new(config: AudioTagConfig) -> Result<Self> {
        let config_clone = config.clone();

        let model = RawCStr::new(&config.model);
        let ced = RawCStr::new(&config.ced.unwrap_or_default());
        let labels = RawCStr::new(&config.labels);
        let provider = RawCStr::new(&config.provider.unwrap_or(get_default_provider()));

        let sherpa_config = sherpa_rs_sys::SherpaOnnxAudioTaggingConfig {
            model: sherpa_rs_sys::SherpaOnnxAudioTaggingModelConfig {
                zipformer: sherpa_rs_sys::SherpaOnnxOfflineZipformerAudioTaggingModelConfig {
                    model: model.as_ptr(),
                },
                ced: ced.as_ptr(),
                num_threads: config.num_threads.unwrap_or(1),
                debug: config.debug.into(),
                provider: provider.as_ptr(),
            },
            labels: labels.as_ptr(),
            top_k: config.top_k,
        };
        let audio_tag = unsafe { sherpa_rs_sys::SherpaOnnxCreateAudioTagging(&sherpa_config) };

        if audio_tag.is_null() {
            bail!("Failed to create audio tagging")
        }
        Ok(Self {
            audio_tag,
            config: config_clone,
        })
    }

    pub fn compute(&mut self, samples: Vec<f32>, sample_rate: u32) -> Vec<String> {
        let mut events = Vec::new();
        unsafe {
            let stream = sherpa_rs_sys::SherpaOnnxAudioTaggingCreateOfflineStream(self.audio_tag);
            sherpa_rs_sys::SherpaOnnxAcceptWaveformOffline(
                stream,
                sample_rate as i32,
                samples.as_ptr(),
                samples.len() as i32,
            );
            let results = sherpa_rs_sys::SherpaOnnxAudioTaggingCompute(
                self.audio_tag,
                stream,
                self.config.top_k,
            );

            for i in 0..self.config.top_k {
                let event = *results.add(i.try_into().unwrap());
                let event_name = cstr_to_string((*event).name);
                events.push(event_name);
            }
        }
        events
    }
}