Skip to content

Commit 06ae4a7

Browse files
Add tdt duration to APIs (k2-fsa#2514)
1 parent ba13109 commit 06ae4a7

File tree

12 files changed

+219
-24
lines changed

12 files changed

+219
-24
lines changed

c-api-examples/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ target_link_libraries(fire-red-asr-c-api sherpa-onnx-c-api)
5959
add_executable(nemo-canary-c-api nemo-canary-c-api.c)
6060
target_link_libraries(nemo-canary-c-api sherpa-onnx-c-api)
6161

62+
add_executable(nemo-parakeet-c-api nemo-parakeet-c-api.c)
63+
target_link_libraries(nemo-parakeet-c-api sherpa-onnx-c-api)
64+
6265
add_executable(sense-voice-c-api sense-voice-c-api.c)
6366
target_link_libraries(sense-voice-c-api sherpa-onnx-c-api)
6467

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// c-api-examples/nemo-parakeet-c-api.c
2+
// Example using the C API and sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8 model
3+
// Prints recognized text, per-token timestamps, and durations
4+
5+
#include <stdio.h>
6+
#include <stdlib.h>
7+
#include <string.h>
8+
9+
#include "sherpa-onnx/c-api/c-api.h"
10+
11+
int32_t main() {
12+
const char *wav_filename =
13+
"./sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/test_wavs/en.wav";
14+
const char *encoder_filename =
15+
"sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/encoder.int8.onnx";
16+
const char *decoder_filename =
17+
"sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/decoder.int8.onnx";
18+
const char *joiner_filename =
19+
"sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/joiner.int8.onnx";
20+
const char *tokens_filename =
21+
"sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/tokens.txt";
22+
const char *provider = "cpu";
23+
24+
if (!SherpaOnnxFileExists(wav_filename)) {
25+
fprintf(stderr, "File not found: %s\n", wav_filename);
26+
return -1;
27+
}
28+
const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename);
29+
if (wave == NULL) {
30+
fprintf(stderr, "Failed to read or parse %s (not a valid mono 16-bit WAVE file)\n", wav_filename);
31+
return -1;
32+
}
33+
34+
SherpaOnnxOfflineModelConfig offline_model_config;
35+
memset(&offline_model_config, 0, sizeof(offline_model_config));
36+
offline_model_config.debug = 0;
37+
offline_model_config.num_threads = 1;
38+
offline_model_config.provider = provider;
39+
offline_model_config.tokens = tokens_filename;
40+
offline_model_config.transducer.encoder = encoder_filename;
41+
offline_model_config.transducer.decoder = decoder_filename;
42+
offline_model_config.transducer.joiner = joiner_filename;
43+
44+
SherpaOnnxOfflineRecognizerConfig recognizer_config;
45+
memset(&recognizer_config, 0, sizeof(recognizer_config));
46+
recognizer_config.decoding_method = "greedy_search";
47+
recognizer_config.model_config = offline_model_config;
48+
49+
const SherpaOnnxOfflineRecognizer *recognizer =
50+
SherpaOnnxCreateOfflineRecognizer(&recognizer_config);
51+
if (recognizer == NULL) {
52+
fprintf(stderr, "Please check your config!\n");
53+
SherpaOnnxFreeWave(wave);
54+
return -1;
55+
}
56+
57+
const SherpaOnnxOfflineStream *stream =
58+
SherpaOnnxCreateOfflineStream(recognizer);
59+
if (stream == NULL) {
60+
fprintf(stderr, "Failed to create offline stream.\n");
61+
SherpaOnnxDestroyOfflineRecognizer(recognizer);
62+
SherpaOnnxFreeWave(wave);
63+
return -1;
64+
}
65+
66+
SherpaOnnxAcceptWaveformOffline(stream, wave->sample_rate, wave->samples,
67+
wave->num_samples);
68+
SherpaOnnxDecodeOfflineStream(recognizer, stream);
69+
const SherpaOnnxOfflineRecognizerResult *result =
70+
SherpaOnnxGetOfflineStreamResult(stream);
71+
72+
printf("Recognized text: %s\n", result->text);
73+
74+
if (result->tokens_arr && result->timestamps && result->durations) {
75+
printf("Token\tTimestamp\tDuration\n");
76+
for (int32_t i = 0; i < result->count; ++i) {
77+
printf("%s\t%.2f\t%.2f\n", result->tokens_arr[i], result->timestamps[i], result->durations[i]);
78+
}
79+
} else {
80+
printf("Timestamps or durations not available.\n");
81+
}
82+
83+
SherpaOnnxDestroyOfflineRecognizerResult(result);
84+
SherpaOnnxDestroyOfflineStream(stream);
85+
SherpaOnnxDestroyOfflineRecognizer(recognizer);
86+
SherpaOnnxFreeWave(wave);
87+
88+
return 0;
89+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Example using the sherpa-onnx Python API and sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8 model
2+
# Prints recognized text, per-token timestamps, and durations
3+
4+
import os
5+
import sys
6+
import sherpa_onnx
7+
import soundfile as sf
8+
9+
# Paths to the test wave file and to the model files of
# sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8.
wav_filename = "./sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/test_wavs/en.wav"
encoder = "./sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/encoder.int8.onnx"
decoder = "./sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/decoder.int8.onnx"
joiner = "./sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/joiner.int8.onnx"
tokens = "./sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8/tokens.txt"

# Bail out early with a non-zero exit status if the input audio is missing.
if not os.path.exists(wav_filename):
    print(f"File not found: {wav_filename}")
    sys.exit(1)


recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
    encoder,
    decoder,
    joiner,
    tokens,
    num_threads=1,
    provider="cpu",
    debug=False,
    decoding_method="greedy_search",
    model_type="nemo_transducer",
)

# always_2d=True yields a (frames, channels) array even for mono input, so the
# first-channel selection below works for any channel count.
audio, sample_rate = sf.read(wav_filename, dtype="float32", always_2d=True)
audio = audio[:, 0]  # use first channel if multi-channel
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio)
recognizer.decode_stream(stream)
result = stream.result

print(f"Recognized text: {result.text}")

# Look the per-token fields up defensively (an attribute may be absent) and
# require them to be non-empty: an attribute that exists but holds no entries
# would otherwise print a table header followed by no rows.
per_token_fields = [
    getattr(result, name, None) for name in ("tokens", "timestamps", "durations")
]

if all(field is not None and len(field) > 0 for field in per_token_fields):
    result_tokens, result_timestamps, result_durations = per_token_fields
    print("Token\tTimestamp\tDuration")
    for token, ts, dur in zip(result_tokens, result_timestamps, result_durations):
        print(f"{token}\t{ts:.2f}\t{dur:.2f}")
else:
    print("Timestamps or durations not available.")

scripts/go/sherpa_onnx.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,7 @@ type OfflineRecognizerResult struct {
523523
Text string
524524
Tokens []string
525525
Timestamps []float32
526+
Durations []float32
526527
Lang string
527528
Emotion string
528529
Event string
@@ -872,13 +873,19 @@ func (s *OfflineStream) GetResult() *OfflineRecognizerResult {
872873
for i := 0; i < n; i++ {
873874
result.Tokens[i] = C.GoString(tokens[i])
874875
}
875-
if p.timestamps == nil {
876-
return result
876+
if p.timestamps != nil {
877+
result.Timestamps = make([]float32, n)
878+
timestamps := unsafe.Slice(p.timestamps, n)
879+
for i := 0; i < n; i++ {
880+
result.Timestamps[i] = float32(timestamps[i])
881+
}
877882
}
878-
result.Timestamps = make([]float32, n)
879-
timestamps := unsafe.Slice(p.timestamps, n)
880-
for i := 0; i < n; i++ {
881-
result.Timestamps[i] = float32(timestamps[i])
883+
if p.durations != nil {
884+
result.Durations = make([]float32, n)
885+
durations := unsafe.Slice(p.durations, n)
886+
for i := 0; i < n; i++ {
887+
result.Durations[i] = float32(durations[i])
888+
}
882889
}
883890
return result
884891
}

sherpa-onnx/c-api/c-api.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,14 @@ const SherpaOnnxOfflineRecognizerResult *SherpaOnnxGetOfflineStreamResult(
689689
r->timestamps = nullptr;
690690
}
691691

692+
if (!result.durations.empty() && result.durations.size() == r->count) {
693+
r->durations = new float[r->count];
694+
std::copy(result.durations.begin(), result.durations.end(),
695+
r->durations);
696+
} else {
697+
r->durations = nullptr;
698+
}
699+
692700
r->tokens = tokens;
693701
} else {
694702
r->count = 0;
@@ -705,6 +713,7 @@ void SherpaOnnxDestroyOfflineRecognizerResult(
705713
if (r) {
706714
delete[] r->text;
707715
delete[] r->timestamps;
716+
delete[] r->durations;
708717
delete[] r->tokens;
709718
delete[] r->tokens_arr;
710719
delete[] r->json;

sherpa-onnx/c-api/c-api.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
614614
// It is NULL if the model does not support timestamps
615615
float *timestamps;
616616

617+
// Pointer to continuous memory which holds durations (in seconds) for each token
618+
// It is NULL if the model does not support durations
619+
float *durations;
620+
617621
// number of entries in tokens_arr, and in timestamps and durations when they are not NULL
618622
int32_t count;
619623

@@ -631,6 +635,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
631635
* "text": "The recognition result",
632636
* "tokens": [x, x, x],
633637
* "timestamps": [x, x, x],
638+
* "durations": [x, x, x],
634639
* "segment": x,
635640
* "start_time": x,
636641
* "is_final": true|false

sherpa-onnx/csrc/offline-recognizer-transducer-impl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ static OfflineRecognitionResult Convert(
3636
OfflineRecognitionResult r;
3737
r.tokens.reserve(src.tokens.size());
3838
r.timestamps.reserve(src.timestamps.size());
39+
r.durations.reserve(src.durations.size());
3940

4041
std::string text;
4142
for (auto i : src.tokens) {
@@ -66,6 +67,11 @@ static OfflineRecognitionResult Convert(
6667
r.timestamps.push_back(time);
6768
}
6869

70+
// Convert durations from output frames to seconds (no-op if src has none)
71+
for (auto d : src.durations) {
72+
r.durations.push_back(d * frame_shift_s);
73+
}
74+
6975
return r;
7076
}
7177

sherpa-onnx/csrc/offline-stream.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,18 @@ std::string OfflineRecognitionResult::AsJsonString() const {
396396
}
397397
os << "], ";
398398

399+
os << "\""
400+
<< "durations"
401+
<< "\""
402+
<< ": ";
403+
os << "[";
404+
sep = "";
405+
for (auto d : durations) {
406+
os << sep << std::fixed << std::setprecision(2) << d;
407+
sep = ", ";
408+
}
409+
os << "], ";
410+
399411
os << "\""
400412
<< "tokens"
401413
<< "\""

sherpa-onnx/csrc/offline-stream.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ struct OfflineRecognitionResult {
3838
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
3939
std::vector<float> timestamps;
4040

41+
/// durations[i] contains the duration (in seconds) for tokens[i] (TDT models only)
42+
std::vector<float> durations;
43+
4144
std::vector<int32_t> words;
4245

4346
std::string AsJsonString() const;

sherpa-onnx/csrc/offline-transducer-decoder.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ struct OfflineTransducerDecoderResult {
1919
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
2020
/// Note: The index is after subsampling
2121
std::vector<int32_t> timestamps;
22+
23+
/// durations[i] contains the duration for tokens[i] in output frames
24+
/// (post-subsampling). It is converted to seconds by higher layers
25+
/// (e.g., Convert() in offline-recognizer-transducer-impl.h).
26+
std::vector<float> durations;
2227
};
2328

2429
class OfflineTransducerDecoder {

0 commit comments

Comments
 (0)