
Commit f3e41de

Merge remote-tracking branch 'upstream/master'

2 parents: ec9eb18 + 36eed0c

31 files changed: +2992 −2040 lines

README.md

Lines changed: 2 additions & 1 deletion
```diff
@@ -10,7 +10,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
 
 ### Hot topics
 
-- ⚠️ **Upcoming change that might break functionality. Help with testing is needed:** https://github.com/ggerganov/llama.cpp/pull/3912
+- *No hot topics atm. Open to suggestions about what is hot today*
 
 ----
 
@@ -93,6 +93,7 @@ as the main playground for developing new features for the [ggml](https://github
 - [X] [Persimmon 8B](https://github.com/ggerganov/llama.cpp/pull/3410)
 - [X] [MPT](https://github.com/ggerganov/llama.cpp/pull/3417)
 - [X] [Bloom](https://github.com/ggerganov/llama.cpp/pull/3553)
+- [X] [StableLM-3b-4e1t](https://github.com/ggerganov/llama.cpp/pull/3586)
 
 
 **Bindings:**
```

common/train.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -32,6 +32,7 @@ struct train_state * init_train_state() {
     state->opt = new struct ggml_opt_context;
     state->opt->ctx = NULL;
     state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
+    state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
     state->opt->loss_after = 0.0f;
 
     return state;
```

common/train.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -9,6 +9,8 @@
 #include "ggml.h"
 #include "llama.h"
 
+#define LLAMA_TRAIN_MAX_NODES 16384
+
 typedef std::string mt19937_state;
 
 struct train_state {
```

convert-hf-to-gguf.py

Lines changed: 19 additions & 11 deletions
```diff
@@ -150,8 +150,6 @@ def load_hparams(dir_model):
 
     @staticmethod
     def from_model_architecture(model_architecture):
-        if model_architecture == "StableLMEpochForCausalLM":
-            return StableLMModel
         if model_architecture == "GPTNeoXForCausalLM":
             return GPTNeoXModel
         if model_architecture == "BloomForCausalLM":
@@ -168,6 +166,8 @@ def from_model_architecture(model_architecture):
             return RefactModel
         if model_architecture == "PersimmonForCausalLM":
             return PersimmonModel
+        if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
+            return StableLMModel
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -201,6 +201,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.REFACT
         if arch == "PersimmonForCausalLM":
             return gguf.MODEL_ARCH.PERSIMMON
+        if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
+            return gguf.MODEL_ARCH.STABLELM
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -294,15 +296,6 @@ def _set_vocab_sentencepiece(self):
         special_vocab.add_to_gguf(self.gguf_writer)
 
 
-class StableLMModel(Model):
-    def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        self.gguf_writer.add_rope_dimension_count(
-            int(self.hparams["rope_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
-        )
-        self.gguf_writer.add_layer_norm_eps(1e-5)
-
-
 class GPTNeoXModel(Model):
     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
@@ -824,6 +817,21 @@ def write_tensors(self):
             self.gguf_writer.add_tensor(new_name, data)
 
 
+class StableLMModel(Model):
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        self.gguf_writer.add_name(dir_model.name)
+        self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+        self.gguf_writer.add_rope_dimension_count(int(hparams["rope_pct"]*(hparams["hidden_size"] // hparams["num_attention_heads"])))
+        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
+        self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
+        self.gguf_writer.add_layer_norm_eps(1e-5)
+
 ###### CONVERSION LOGIC ######
 
 def parse_args() -> argparse.Namespace:
```

convert.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -1036,7 +1036,8 @@ def load_some_model(path: Path) -> ModelPlus:
     # Be extra-friendly and accept either a file or a directory:
     if path.is_dir():
         # Check if it's a set of safetensors files first
-        files = list(path.glob("model-00001-of-*.safetensors"))
+        globs = ["model-00001-of-*.safetensors", "model.safetensors"]
+        files = [file for glob in globs for file in path.glob(glob)]
         if not files:
             # Try the PyTorch patterns too, with lower priority
             globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"]
@@ -1123,7 +1124,7 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
     parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
-    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
+    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin, *.safetensors)")
    parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
     parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
     parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
```

examples/benchmark/benchmark-matmult.cpp

Lines changed: 12 additions & 9 deletions
```diff
@@ -171,7 +171,8 @@ int main(int argc, char ** argv) {
     struct ggml_tensor * m11xm2 = ggml_mul_mat(ctx, m11, m2);
 
     // printf("Creating compute graph\n");
-    struct ggml_cgraph gf = ggml_build_forward(m11xm2);
+    struct ggml_cgraph * gf = ggml_new_graph(ctx);
+    ggml_build_forward_expand(gf, m11xm2);
 
     printf("n_threads=%i\n", benchmark_params.n_threads);
 
@@ -180,9 +181,9 @@
 
     std::vector<uint8_t> work_buffer;
 
-    ggml_graph_compute_helper(work_buffer, &gf, benchmark_params.n_threads);
+    ggml_graph_compute_helper(work_buffer, gf, benchmark_params.n_threads);
 
-    TENSOR_DUMP(gf.nodes[0]);
+    TENSOR_DUMP(gf->nodes[0]);
 
     printf("\n------ Test 2 - Matrix Mult via %s code\n", ggml_type_name(qtype));
 
@@ -200,7 +201,8 @@
     struct ggml_tensor * q31 = ggml_mul_mat(ctx, q11, m2);
 
     // printf("Creating compute graph\n");
-    struct ggml_cgraph gf31 = ggml_build_forward(q31);
+    struct ggml_cgraph * gf31 = ggml_new_graph(ctx);
+    ggml_build_forward_expand(gf31, q31);
 
     // Set up a second graph computation to make sure we override the CPU cache lines
     // printf("Creating new tensor q12 & Running quantize\n");
@@ -211,7 +213,8 @@
     struct ggml_tensor * q32 = ggml_mul_mat(ctx, q12, m2);
 
     //printf("Creating compute graph\n");
-    struct ggml_cgraph gf32 = ggml_build_forward(q32);
+    struct ggml_cgraph * gf32 = ggml_new_graph(ctx);
+    ggml_build_forward_expand(gf32, q32);
     printf("n_threads=%i\n", benchmark_params.n_threads);
 
     const int dimx = sizex;
@@ -223,7 +226,7 @@
 
 
     // Let's use the F32 result from above as a reference for the quantized multiplication
-    float sum_of_F32_reference = tensor_sum_elements(gf.nodes[0]);
+    float sum_of_F32_reference = tensor_sum_elements(gf->nodes[0]);
 
     printf("Iteration;NThreads; SizeX; SizeY; SizeZ; Required_FLOPS; Elapsed_u_Seconds; gigaFLOPS\n");
     printf("=====================================================================================\n");
@@ -233,7 +236,7 @@
 
         long long int start = ggml_time_us();
         //printf("Running ggml_graph_compute\n");
-        ggml_graph_compute_helper(work_buffer, &gf31, benchmark_params.n_threads);
+        ggml_graph_compute_helper(work_buffer, gf31, benchmark_params.n_threads);
 
         long long int stop = ggml_time_us();
         long long int usec = stop-start;
@@ -251,7 +254,7 @@
 
         // Check that the matrix multiplication result is in the right ballpark
         // We cannot use the exact value from the F32 multiplication because the quantizuation will be slightly different
-        float sum_of_Q4_result = tensor_sum_elements(gf31.nodes[0]);
+        float sum_of_Q4_result = tensor_sum_elements(gf31->nodes[0]);
         float delta = std::abs(sum_of_Q4_result - sum_of_F32_reference);
         float allowed_delta = (sum_of_F32_reference) / 1000 / 1000; // Let's accept an epsilon of 10^-6
 
@@ -266,7 +269,7 @@
         }
 
         // Running a different graph computation to make sure we override the CPU cache lines
-        ggml_graph_compute_helper(work_buffer, &gf32, benchmark_params.n_threads);
+        ggml_graph_compute_helper(work_buffer, gf32, benchmark_params.n_threads);
     }
     printf("\n");
     printf("Average%78.2f\n",gflops_sum/((double)benchmark_params.n_iterations));
```

examples/export-lora/export-lora.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -240,7 +240,7 @@ static struct lora_data * load_lora(struct lora_info * info) {
     }
 
     struct ggml_init_params params_ggml;
-    params_ggml.mem_size = ggml_tensor_overhead() * GGML_MAX_NODES;
+    params_ggml.mem_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE;
     params_ggml.mem_buffer = NULL;
     params_ggml.no_alloc = true;
     result->ctx = ggml_init(params_ggml);
@@ -334,7 +334,7 @@ static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int
     float scaling = lora->info.scale * (float)lora->lora_alpha / (float)lora->lora_r;
 
     struct ggml_init_params params;
-    params.mem_size = GGML_OBJECT_SIZE + GGML_GRAPH_SIZE + ggml_tensor_overhead()*4 + GGML_MEM_ALIGN*5;
+    params.mem_size = GGML_OBJECT_SIZE + ggml_graph_overhead() + ggml_tensor_overhead()*4 + GGML_MEM_ALIGN*5;
     params.mem_buffer = NULL;
     params.no_alloc = true;
     struct ggml_context * ctx = NULL;
```

examples/finetune/finetune.cpp

Lines changed: 11 additions & 12 deletions
```diff
@@ -772,7 +772,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     if (enable_checkpointing) {
         ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
     } else {
-        *gb = *gf;
+        ggml_graph_cpy(gf, gb);
         ggml_build_backward_expand(ctx, gf, gb, true);
     }
 
@@ -1615,6 +1615,7 @@ int main(int argc, char ** argv) {
     opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
     opt->params.print_forward_graph = false;
     opt->params.print_backward_graph = false;
+    opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
     opt->params.n_threads = params.common.n_threads;
     opt->params.past = params.common.opt_past;
     opt->params.delta = params.common.opt_delta;
@@ -1741,11 +1742,9 @@
     ggml_allocr_free(alloc);
 
     // context for compute tensors without their data
-    size_t estimated_compute_size_wo_data = (
-        ggml_tensor_overhead()*GGML_MAX_NODES*2
-        + (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
-            params.common.use_checkpointing ? 3 : 2
-        )
+    const size_t estimated_compute_size_wo_data = (
+        2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
+        (params.common.use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE+ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true))
     );
     struct ggml_init_params ctx_compute_params = {
         estimated_compute_size_wo_data, // mem_size
@@ -1768,11 +1767,11 @@
     for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
         ctx_compute = ggml_init(ctx_compute_params);
         alloc = ggml_allocr_new_measure(tensor_alignment);
-        gf = ggml_new_graph(ctx_compute);
+        gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gf->order = (enum ggml_cgraph_eval_order) order;
-        gb = ggml_new_graph(ctx_compute);
+        gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
         gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph(ctx_compute)
+            ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
             : NULL;
         loss = llama_build_lora_finetune_graphs(
             &model, &lora, alloc, ctx_compute,
@@ -1801,11 +1800,11 @@
     mem_compute_data.resize(max_compute_size);
     ctx_compute = ggml_init(ctx_compute_params);
     alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
-    gf = ggml_new_graph(ctx_compute);
+    gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
    gf->order = best_order;
-    gb = ggml_new_graph(ctx_compute);
+    gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
     gb_tmp = params.common.use_checkpointing
-        ? ggml_new_graph(ctx_compute)
+        ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
         : NULL;
     loss = llama_build_lora_finetune_graphs(
         &model, &lora, alloc, ctx_compute,
```

examples/llava/clip.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -664,7 +664,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     // measure mem requirement and allocate
     {
         static const size_t tensor_alignment = 32;
-        new_clip->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+        new_clip->buf_compute.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());
         new_clip->alloc = ggml_allocr_new_measure(tensor_alignment);
         clip_image_f32_batch batch;
         batch.size = 1;
@@ -761,7 +761,7 @@ bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip
         temp->ny = img->ny;
         temp->size = img->size;
         temp->data = new uint8_t[temp->size]();
-        *temp->data = *img->data; // copy
+        memcpy(&temp->data[0], &img->data[0], temp->size); // copy
     }
 
     const int nx = temp->nx;
```

examples/metal/metal.cpp

Lines changed: 5 additions & 5 deletions
```diff
@@ -34,7 +34,7 @@ int main(int argc, char ** argv) {
     struct ggml_context * ctx_data = NULL;
     struct ggml_context * ctx_eval = NULL;
 
-    struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
+    struct ggml_cgraph * gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
 
     // this allocates all Metal resources and memory buffers
     auto * ctx_metal = ggml_metal_init(1);
@@ -46,21 +46,21 @@
 
     // main
     {
-        struct ggml_tensor * input = ggml_graph_get_tensor(&gf, "embd");
+        struct ggml_tensor * input = ggml_graph_get_tensor(gf, "embd");
         *(int32_t *) input->data = 1; // BOS
 
         ggml_metal_set_tensor(ctx_metal, input);
 
         // warmup
-        ggml_metal_graph_compute(ctx_metal, &gf);
+        ggml_metal_graph_compute(ctx_metal, gf);
 
         const int n_iter = 16;
 
         const int64_t t0 = ggml_time_us();
 
         // the actual inference happens here
         for (int i = 0; i < n_iter; ++i) {
-            ggml_metal_graph_compute(ctx_metal, &gf);
+            ggml_metal_graph_compute(ctx_metal, gf);
         }
 
         const int64_t t1 = ggml_time_us();
@@ -70,7 +70,7 @@
 
     // debug output
     {
-        struct ggml_tensor * logits = gf.nodes[gf.n_nodes - 1];
+        struct ggml_tensor * logits = gf->nodes[gf->n_nodes - 1];
         ggml_metal_get_tensor(ctx_metal, logits);
 
         float * ptr = (float *) ggml_get_data(logits);
```
