@@ -2566,6 +2566,85 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
25662566 return nread;
25672567}
25682568
2569+ bool llama_load_session_file (struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
2570+ llama_file file (path_session, " rb" );
2571+
2572+ // sanity checks
2573+ {
2574+ const uint32_t magic = file.read_u32 ();
2575+ const uint32_t version = file.read_u32 ();
2576+
2577+ if (!(magic == LLAMA_SESSION_MAGIC && version == LLAMA_SESSION_VERSION)) {
2578+ fprintf (stderr, " %s : unknown (magic, version) for session file: %08x, %08x\n " , __func__, magic, version);
2579+ return false ;
2580+ }
2581+
2582+ llama_hparams session_hparams;
2583+ file.read_raw (&session_hparams, sizeof (llama_hparams));
2584+
2585+ if (session_hparams != ctx->model .hparams ) {
2586+ fprintf (stderr, " %s : model hparams didn't match from session file!\n " , __func__);
2587+ return false ;
2588+ }
2589+ }
2590+
2591+ // load the prompt
2592+ {
2593+ const uint32_t n_token_count = file.read_u32 ();
2594+
2595+ if (n_token_count > n_token_capacity) {
2596+ fprintf (stderr, " %s : token count in session file exceeded capacity! %u > %zu\n " , __func__, n_token_count, n_token_capacity);
2597+ return false ;
2598+ }
2599+
2600+ file.read_raw (tokens_out, sizeof (llama_token) * n_token_count);
2601+ *n_token_count_out = n_token_count;
2602+ }
2603+
2604+ // restore the context state
2605+ {
2606+ const size_t n_state_size_cur = file.size - file.tell ();
2607+ const size_t n_state_size_exp = llama_get_state_size (ctx);
2608+
2609+ if (n_state_size_cur != n_state_size_exp) {
2610+ fprintf (stderr, " %s : the state size in session file didn't match! expected %zu, got %zu\n " , __func__, n_state_size_exp, n_state_size_cur);
2611+ return false ;
2612+ }
2613+
2614+ std::vector<uint8_t > state_data (n_state_size_cur);
2615+ file.read_raw (state_data.data (), n_state_size_cur);
2616+
2617+ llama_set_state_data (ctx, state_data.data ());
2618+ }
2619+
2620+ return true ;
2621+ }
2622+
2623+ bool llama_save_session_file (struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
2624+ llama_file file (path_session, " wb" );
2625+
2626+ file.write_u32 (LLAMA_SESSION_MAGIC);
2627+ file.write_u32 (LLAMA_SESSION_VERSION);
2628+
2629+ file.write_raw (&ctx->model .hparams , sizeof (llama_hparams));
2630+
2631+ // save the prompt
2632+ file.write_u32 ((uint32_t ) n_token_count);
2633+ file.write_raw (tokens, sizeof (llama_token) * n_token_count);
2634+
2635+ // save the context state
2636+ {
2637+ const size_t n_state_size = llama_get_state_size (ctx);
2638+
2639+ std::vector<uint8_t > state_data (n_state_size);
2640+ llama_copy_state_data (ctx, state_data.data ());
2641+
2642+ file.write_raw (state_data.data (), n_state_size);
2643+ }
2644+
2645+ return true ;
2646+ }
2647+
25692648int llama_eval (
25702649 struct llama_context * ctx,
25712650 const llama_token * tokens,
@@ -2693,57 +2772,3 @@ const char * llama_print_system_info(void) {
26932772std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map (struct llama_context * ctx) {
26942773 return ctx->model .tensors_by_name ;
26952774}
2696-
2697- size_t llama_load_session_file (struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
2698- // TODO leverage mmap
2699- llama_file file (path_session, " rb" );
2700- const uint32_t magic = file.read_u32 ();
2701- const uint32_t version = file.read_u32 ();
2702-
2703- if (!(magic == ' ggsn' && version == 0 )) {
2704- fprintf (stderr, " %s : unknown (magic, version) for session file: %08x, %08x\n " , __func__, magic, version);
2705- return 0 ;
2706- }
2707-
2708- llama_hparams session_hparams;
2709- file.read_raw (&session_hparams, sizeof (llama_hparams));
2710-
2711- // REVIEW
2712- if (session_hparams != ctx->model .hparams ) {
2713- fprintf (stderr, " %s : model hparams didn't match from session file!\n " , __func__);
2714- return 0 ;
2715- }
2716-
2717- const uint32_t n_token_count = file.read_u32 ();
2718- LLAMA_ASSERT (n_token_capacity >= n_token_count);
2719- file.read_raw (tokens_out, sizeof (llama_token) * n_token_count);
2720- *n_token_count_out = n_token_count;
2721-
2722- const size_t n_state_size = file.size - file.tell ();
2723- const size_t n_orig_state_size = llama_get_state_size (ctx);
2724- if (n_state_size != n_orig_state_size) {
2725- fprintf (stderr, " %s : failed to validate state size\n " , __func__);
2726- }
2727- std::unique_ptr<uint8_t []> state_data (new uint8_t [n_state_size]);
2728- file.read_raw (state_data.get (), n_state_size);
2729- return llama_set_state_data (ctx, state_data.get ());
2730- }
2731-
2732- size_t llama_save_session_file (struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
2733- // TODO save temp & swap
2734- llama_file file (path_session, " wb" );
2735-
2736- const size_t n_state_size = llama_get_state_size (ctx);
2737- std::unique_ptr<uint8_t []> state_data (new uint8_t [n_state_size]);
2738- llama_copy_state_data (ctx, state_data.get ());
2739-
2740- file.write_u32 (' ggsn' ); // magic
2741- file.write_u32 (0 ); // version
2742- file.write_raw (&ctx->model .hparams , sizeof (llama_hparams));
2743-
2744- file.write_u32 ((uint32_t ) n_token_count); // REVIEW
2745- file.write_raw (tokens, sizeof (llama_token) * n_token_count);
2746-
2747- file.write_raw (state_data.get (), n_state_size);
2748- return n_state_size; // REVIEW
2749- }
0 commit comments