Skip to content

Commit 3b2cd79

Browse files
committed
feat(kokoro): F16-quantize voice_tensors
1 parent 0b42010 commit 3b2cd79

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

src/kokoro_model.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "kokoro_model.h"
22

3+
#define ggml_cast_if_needed(ctx, x, qtype) (x->type == qtype ? x : ggml_cast(ctx, x, qtype))
4+
35
static struct ggml_tensor * build_albert_attn_mask(ggml_context * ctx, struct kokoro_duration_context *kctx, const kokoro_ubatch & batch) {
46
kctx->attn_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, (int64_t) batch.n_tokens, (int64_t) batch.n_tokens);
57
ggml_set_input(kctx->attn_mask);
@@ -943,7 +945,7 @@ struct ggml_cgraph * kokoro_duration_runner::build_kokoro_duration_graph(kokoro_
943945
// In order to side step this problem I computed the graph and determined the size in advance and use that constant value here.
944946
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, 110000, false);
945947

946-
struct ggml_tensor * voice = model->voices[kctx->voice];
948+
struct ggml_tensor * voice = ggml_cast_if_needed(ctx, model->voices[kctx->voice], GGML_TYPE_F32);
947949
struct ggml_tensor * cur;
948950
struct ggml_tensor * inpL;
949951

@@ -1146,7 +1148,7 @@ struct ggml_cgraph * kokoro_runner::build_kokoro_graph(kokoro_ubatch & batch) {
11461148
// In order to side step this problem I computed the graph and determined the size in advance and use that constant value here.
11471149
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, 570000, false);
11481150

1149-
struct ggml_tensor * voice = model->voices[kctx->voice];
1151+
struct ggml_tensor * voice = ggml_cast_if_needed(ctx, model->voices[kctx->voice], GGML_TYPE_F32);
11501152
struct ggml_tensor * style_half = ggml_view_1d(ctx, voice, voice->ne[0]/2, voice->ne[0] / 2 * voice->nb[0] + (batch.n_tokens - 3) * voice->nb[1]);
11511153
struct ggml_tensor * cur;
11521154

src/tts.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,7 @@ void update_conditional_prompt(tts_runner * runner, const std::string file_path,
191191
}
192192

193193
bool kokoro_is_f16_compatible(std::string name) {
194-
return name.find("voice_tensors") == std::string::npos &&
195-
name.find("bias") == std::string::npos &&
194+
return name.find("bias") == std::string::npos &&
196195
name.find("gamma") == std::string::npos &&
197196
name.find("beta") == std::string::npos &&
198197
name.find("alpha") == std::string::npos &&

0 commit comments

Comments
 (0)