diff --git a/.github/workflows/image-pr.yml b/.github/workflows/image-pr.yml index ae8bd070a125..527a8479ee39 100644 --- a/.github/workflows/image-pr.yml +++ b/.github/workflows/image-pr.yml @@ -51,6 +51,14 @@ jobs: image-type: 'extras' runs-on: 'arc-runner-set' base-image: "ubuntu:22.04" + - build-type: 'hipblas' + platforms: 'linux/amd64' + tag-latest: 'false' + tag-suffix: '-hipblas' + ffmpeg: 'false' + image-type: 'extras' + base-image: "rocm/dev-ubuntu-22.04:6.0-complete" + runs-on: 'arc-runner-set' core-image-build: uses: ./.github/workflows/image_build.yml with: diff --git a/.github/workflows/image.yml b/.github/workflows/image.yml index ac61deeca6e8..830528a1a18a 100644 --- a/.github/workflows/image.yml +++ b/.github/workflows/image.yml @@ -103,6 +103,22 @@ jobs: image-type: 'extras' base-image: "ubuntu:22.04" runs-on: 'arc-runner-set' + - build-type: 'hipblas' + platforms: 'linux/amd64' + tag-latest: 'false' + tag-suffix: '-hipblas-ffmpeg' + ffmpeg: 'true' + image-type: 'extras' + base-image: "rocm/dev-ubuntu-22.04:6.0-complete" + runs-on: 'arc-runner-set' + - build-type: 'hipblas' + platforms: 'linux/amd64' + tag-latest: 'false' + tag-suffix: '-hipblas' + ffmpeg: 'false' + image-type: 'extras' + base-image: "rocm/dev-ubuntu-22.04:6.0-complete" + runs-on: 'arc-runner-set' core-image-build: uses: ./.github/workflows/image_build.yml with: @@ -124,6 +140,22 @@ jobs: strategy: matrix: include: + - build-type: 'hipblas' + platforms: 'linux/amd64' + tag-latest: 'false' + tag-suffix: '-hipblas-ffmpeg-core' + ffmpeg: 'true' + image-type: 'core' + base-image: "rocm/dev-ubuntu-22.04:6.0-complete" + runs-on: 'arc-runner-set' + - build-type: 'hipblas' + platforms: 'linux/amd64' + tag-latest: 'false' + tag-suffix: '-hipblas-core' + ffmpeg: 'false' + image-type: 'core' + base-image: "rocm/dev-ubuntu-22.04:6.0-complete" + runs-on: 'arc-runner-set' - build-type: '' platforms: 'linux/amd64' tag-latest: 'false' diff --git a/Dockerfile b/Dockerfile index 6c5e27457382..a04a866ec7d1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,3 @@ -ARG GO_VERSION=1.21 ARG IMAGE_TYPE=extras ARG BASE_IMAGE=ubuntu:22.04 @@ -42,8 +41,12 @@ RUN if [ "${BUILD_TYPE}" = "cublas" ]; then \ apt-get install -y cuda-nvcc-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcurand-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcublas-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcusparse-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} libcusolver-dev-${CUDA_MAJOR_VERSION}-${CUDA_MINOR_VERSION} && apt-get clean \ ; fi +# Cuda ENV PATH /usr/local/cuda/bin:${PATH} +# HipBLAS requirements +ENV PATH /opt/rocm/bin:${PATH} + # OpenBLAS requirements and stable diffusion RUN apt-get install -y \ libopenblas-dev \ @@ -70,7 +73,9 @@ RUN curl https://repo.anaconda.com/pkgs/misc/gpgkeys/anaconda.asc | gpg --dearmo apt-get install -y conda && apt-get clean ENV PATH="/root/.cargo/bin:${PATH}" +RUN apt-get install -y python3-pip && apt-get clean RUN pip install --upgrade pip + RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y RUN apt-get install -y espeak-ng espeak && apt-get clean diff --git a/Makefile b/Makefile index c63d46f86662..71ca6fcffdc8 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ GOLLAMA_VERSION?=aeba71ee842819da681ea537e78846dc75949ac0 GOLLAMA_STABLE_VERSION?=50cee7712066d9e38306eccadcfbb44ea87df4b7 -CPPLLAMA_VERSION?=f026f8120f97090d34a52b3dc023c82e0ede3f7d +CPPLLAMA_VERSION?=fd43d66f46ee3b5345fb8a74a252d86ccd34a409 # gpt4all version 
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all @@ -97,6 +97,8 @@ endif ifeq ($(BUILD_TYPE),hipblas) ROCM_HOME ?= /opt/rocm + ROCM_PATH ?= /opt/rocm + LD_LIBRARY_PATH ?= /opt/rocm/lib:/opt/rocm/llvm/lib export CXX=$(ROCM_HOME)/llvm/bin/clang++ export CC=$(ROCM_HOME)/llvm/bin/clang # llama-ggml has no hipblas support, so override it here. @@ -105,7 +107,7 @@ ifeq ($(BUILD_TYPE),hipblas) GPU_TARGETS ?= gfx900,gfx90a,gfx1030,gfx1031,gfx1100 AMDGPU_TARGETS ?= "$(GPU_TARGETS)" CMAKE_ARGS+=-DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS="$(AMDGPU_TARGETS)" -DGPU_TARGETS="$(GPU_TARGETS)" - CGO_LDFLAGS += -O3 --rtlib=compiler-rt -unwindlib=libgcc -lhipblas -lrocblas --hip-link + CGO_LDFLAGS += -O3 --rtlib=compiler-rt -unwindlib=libgcc -lhipblas -lrocblas --hip-link -L${ROCM_HOME}/lib/llvm/lib endif ifeq ($(BUILD_TYPE),metal) diff --git a/README.md b/README.md index fa875e5a1f1e..2ae95d8cfb34 100644 --- a/README.md +++ b/README.md @@ -43,20 +43,23 @@ [Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) +- Parallel function calling: https://github.com/mudler/LocalAI/pull/1726 +- Upload file API: https://github.com/mudler/LocalAI/pull/1703 +- Tools API support: https://github.com/mudler/LocalAI/pull/1715 +- LLaVa 1.6: https://github.com/mudler/LocalAI/pull/1714 +- ROCm container images: https://github.com/mudler/LocalAI/pull/1595 - Intel GPU support (sycl): https://github.com/mudler/LocalAI/issues/1653 - Deprecation of old backends: https://github.com/mudler/LocalAI/issues/1651 - Mamba support: https://github.com/mudler/LocalAI/pull/1589 - Start and share models with config file: https://github.com/mudler/LocalAI/pull/1522 - 🐸 Coqui: https://github.com/mudler/LocalAI/pull/1489 -- Inline templates: https://github.com/mudler/LocalAI/pull/1452 -- Mixtral: https://github.com/mudler/LocalAI/pull/1449 - Img2vid https://github.com/mudler/LocalAI/pull/1442 -- Musicgen https://github.com/mudler/LocalAI/pull/1387 Hot topics (looking for contributors): - Backends v2: https://github.com/mudler/LocalAI/issues/1126 - Improving UX v2: https://github.com/mudler/LocalAI/issues/1373 - +- Assistant API: https://github.com/mudler/LocalAI/issues/1273 + If you want to help and contribute, issues up for grabs: https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3A%22up+for+grabs%22 ## 💻 [Getting started](https://localai.io/basics/getting_started/index.html) @@ -95,9 +98,8 @@ WebUIs: Model galleries - https://github.com/go-skynet/model-gallery -Auto Docker / Model setup -- https://io.midori-ai.xyz/howtos/easy-localai-installer/ -- https://io.midori-ai.xyz/howtos/easy-model-installer/ +UI / Management Programs +- [LocalAI Manager](https://io.midori-ai.xyz/howtos/easy-model-installer/) Other: - Helm chart https://github.com/go-skynet/helm-charts diff --git a/api/localai/backend_monitor.go b/api/localai/backend_monitor.go index 8cb0bb45ed14..e6f1b409b142 100644 --- a/api/localai/backend_monitor.go +++ b/api/localai/backend_monitor.go @@ -5,10 +5,10 @@ import ( "fmt" "strings" - config "github.com/go-skynet/LocalAI/api/config" + config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/pkg/grpc/proto" - 
"github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/core/options" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" diff --git a/api/localai/gallery.go b/api/localai/gallery.go index a2ad5bd1ac46..ee6f4d7d4a12 100644 --- a/api/localai/gallery.go +++ b/api/localai/gallery.go @@ -11,7 +11,7 @@ import ( json "github.com/json-iterator/go" "gopkg.in/yaml.v3" - config "github.com/go-skynet/LocalAI/api/config" + config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/utils" diff --git a/api/localai/localai.go b/api/localai/localai.go index 3abe440ea12e..9d5bbf6c5ca7 100644 --- a/api/localai/localai.go +++ b/api/localai/localai.go @@ -1,12 +1,12 @@ package localai import ( - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" fiberContext "github.com/go-skynet/LocalAI/api/ctx" + "github.com/go-skynet/LocalAI/core/backend" + config "github.com/go-skynet/LocalAI/core/config" "github.com/rs/zerolog/log" - "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/core/options" "github.com/gofiber/fiber/v2" ) diff --git a/api/openai/chat.go b/api/openai/chat.go index 819cd6b2d6c4..78d02f96652e 100644 --- a/api/openai/chat.go +++ b/api/openai/chat.go @@ -8,10 +8,10 @@ import ( "strings" "time" - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/core/backend" + config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/pkg/grammar" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/utils" @@ -55,6 +55,102 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) }) close(responses) } + processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { + result := "" + _, tokenUsage, _ := ComputeChoices(req, prompt, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { + result += s + // TODO: Change generated BNF grammar to be compliant with the schema so we can + // stream the result token by token here. + return true + }) + + results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls) + noActionToRun := len(results) > 0 && results[0].name == noAction + + switch { + case noActionToRun: + initialMessage := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. 
+ Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + + result, err := handleQuestion(config, req, o, results[0].arguments, prompt) + if err != nil { + log.Error().Msgf("error handling question: %s", err.Error()) + return + } + + resp := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}}, + Object: "chat.completion.chunk", + Usage: schema.OpenAIUsage{ + PromptTokens: tokenUsage.Prompt, + CompletionTokens: tokenUsage.Completion, + TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, + }, + } + + responses <- resp + + default: + for i, ss := range results { + name, args := ss.name, ss.arguments + + initialMessage := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{ + Delta: &schema.Message{ + Role: "assistant", + ToolCalls: []schema.ToolCall{ + { + Index: i, + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: name, + }, + }, + }, + }}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + + responses <- schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{ + Delta: &schema.Message{ + Role: "assistant", + ToolCalls: []schema.ToolCall{ + { + Index: i, + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Arguments: args, + }, + }, + }, + }}}, + Object: "chat.completion.chunk", + } + } + } + + close(responses) + } + return func(c *fiber.Ctx) error { processFunctions := false funcs := grammar.Functions{} @@ -116,13 +212,13 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) // Update input grammar jsStruct := funcs.ToJSONStructure() - config.Grammar = jsStruct.Grammar("") + config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls) } else if input.JSONFunctionGrammarObject != nil { - config.Grammar = input.JSONFunctionGrammarObject.Grammar("") + config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls) } // functions are not supported in stream mode (yet?) 
- toStream := input.Stream && !processFunctions + toStream := input.Stream log.Debug().Msgf("Parameters: %+v", config) @@ -145,6 +241,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) } r := config.Roles[role] contentExists := i.Content != nil && i.StringContent != "" + // First attempt to populate content via a chat message specific template if config.TemplateConfig.ChatMessage != "" { chatMessageData := model.ChatMessageTemplateData{ @@ -152,6 +249,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) Role: r, RoleName: role, Content: i.StringContent, + FunctionName: i.Name, MessageIndex: messageIndex, } templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) @@ -254,17 +352,24 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) log.Debug().Msgf("Grammar: %+v", config.Grammar) } - if toStream { + switch { + case toStream: responses := make(chan schema.OpenAIResponse) - go process(predInput, input, config, o.Loader, responses) + if !processFunctions { + go process(predInput, input, config, o.Loader, responses) + } else { + go processTools(noActionName, predInput, input, config, o.Loader, responses) + } c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - usage := &schema.OpenAIUsage{} - + toolsCalled := false for ev := range responses { usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it + if len(ev.Choices[0].Delta.ToolCalls) > 0 { + toolsCalled = true + } var buf bytes.Buffer enc := json.NewEncoder(&buf) enc.Encode(ev) @@ -278,13 +383,20 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) w.Flush() } + finishReason := "stop" + if toolsCalled { + finishReason = "tool_calls" + } else if toolsCalled && len(input.Tools) == 0 { + finishReason = "function_call" + } + resp := &schema.OpenAIResponse{ ID: id, Created: created, Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. Choices: []schema.Choice{ { - FinishReason: "stop", + FinishReason: finishReason, Index: 0, Delta: &schema.Message{Content: &emptyMessage}, }}, @@ -298,102 +410,182 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) w.Flush() })) return nil - } - result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) { - if processFunctions { - // As we have to change the result before processing, we can't stream the answer (yet?) 
- ss := map[string]interface{}{} - // This prevent newlines to break JSON parsing for clients - s = utils.EscapeNewLines(s) - json.Unmarshal([]byte(s), &ss) - log.Debug().Msgf("Function return: %s %+v", s, ss) - - // The grammar defines the function name as "function", while OpenAI returns "name" - func_name := ss["function"] - // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object - args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) - d, _ := json.Marshal(args) - - ss["arguments"] = string(d) - ss["name"] = func_name - - // if do nothing, reply with a message - if func_name == noActionName { - log.Debug().Msgf("nothing to do, computing a reply") - - // If there is a message that the LLM already sends as part of the JSON reply, use it - arguments := map[string]interface{}{} - json.Unmarshal([]byte(d), &arguments) - m, exists := arguments["message"] - if exists { - switch message := m.(type) { - case string: - if message != "" { - log.Debug().Msgf("Reply received from LLM: %s", message) - message = backend.Finetune(*config, predInput, message) - log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) - - *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}}) - return - } - } - } + // no streaming mode + default: + result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) { + if !processFunctions { + // no function is called, just reply and use stop as finish reason + *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}}) + return + } - log.Debug().Msgf("No action received from LLM, without a message, computing a reply") - // Otherwise ask the LLM to understand the JSON output and the context, and return a message - // Note: This costs (in term of CPU) another computation - config.Grammar = "" - images := []string{} - for _, m := range input.Messages { - images = append(images, m.StringImages...) 
- } - predFunc, err := backend.ModelInference(input.Context, predInput, images, o.Loader, *config, o, nil) + results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls) + noActionsToRun := len(results) > 0 && results[0].name == noActionName + + switch { + case noActionsToRun: + result, err := handleQuestion(config, input, o, results[0].arguments, predInput) if err != nil { - log.Error().Msgf("inference error: %s", err.Error()) + log.Error().Msgf("error handling question: %s", err.Error()) return } + *c = append(*c, schema.Choice{ + Message: &schema.Message{Role: "assistant", Content: &result}}) + default: + toolChoice := schema.Choice{ + Message: &schema.Message{ + Role: "assistant", + }, + } - prediction, err := predFunc() - if err != nil { - log.Error().Msgf("inference error: %s", err.Error()) - return + if len(input.Tools) > 0 { + toolChoice.FinishReason = "tool_calls" } - fineTunedResponse := backend.Finetune(*config, predInput, prediction.Response) - *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}}) - } else { - // otherwise reply with the function call - *c = append(*c, schema.Choice{ - FinishReason: "function_call", - Message: &schema.Message{Role: "assistant", FunctionCall: ss}, - }) + for _, ss := range results { + name, args := ss.name, ss.arguments + if len(input.Tools) > 0 { + // If we are using tools, we condense the function calls into + // a single response choice with all the tools + toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls, + schema.ToolCall{ + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: name, + Arguments: args, + }, + }, + ) + } else { + // otherwise we return more choices directly + *c = append(*c, schema.Choice{ + FinishReason: "function_call", + Message: &schema.Message{ + Role: "assistant", + FunctionCall: map[string]interface{}{ + "name": name, + "arguments": args, + }, + }, + }) + } + } + + if len(input.Tools) > 0 { + // we need to append our result if we are using tools + *c = append(*c, toolChoice) + } } - return + }, nil) + if err != nil { + return err } - *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}}) - }, nil) - if err != nil { - return err + + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "chat.completion", + Usage: schema.OpenAIUsage{ + PromptTokens: tokenUsage.Prompt, + CompletionTokens: tokenUsage.Completion, + TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, + }, + } + respData, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", respData) + + // Return the prediction in the response body + return c.JSON(resp) } - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
- Choices: result, - Object: "chat.completion", - Usage: schema.OpenAIUsage{ - PromptTokens: tokenUsage.Prompt, - CompletionTokens: tokenUsage.Completion, - TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, - }, + } +} + +func handleQuestion(config *config.Config, input *schema.OpenAIRequest, o *options.Option, args, prompt string) (string, error) { + log.Debug().Msgf("nothing to do, computing a reply") + + // If there is a message that the LLM already sends as part of the JSON reply, use it + arguments := map[string]interface{}{} + json.Unmarshal([]byte(args), &arguments) + m, exists := arguments["message"] + if exists { + switch message := m.(type) { + case string: + if message != "" { + log.Debug().Msgf("Reply received from LLM: %s", message) + message = backend.Finetune(*config, prompt, message) + log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) + + return message, nil + } } - respData, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", respData) + } - // Return the prediction in the response body - return c.JSON(resp) + log.Debug().Msgf("No action received from LLM, without a message, computing a reply") + // Otherwise ask the LLM to understand the JSON output and the context, and return a message + // Note: This costs (in term of CPU/GPU) another computation + config.Grammar = "" + images := []string{} + for _, m := range input.Messages { + images = append(images, m.StringImages...) } + + predFunc, err := backend.ModelInference(input.Context, prompt, images, o.Loader, *config, o, nil) + if err != nil { + log.Error().Msgf("inference error: %s", err.Error()) + return "", err + } + + prediction, err := predFunc() + if err != nil { + log.Error().Msgf("inference error: %s", err.Error()) + return "", err + } + return backend.Finetune(*config, prompt, prediction.Response), nil +} + +type funcCallResults struct { + name string + arguments string +} + +func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults { + results := []funcCallResults{} + + // TODO: use generics to avoid this code duplication + if multipleResults { + ss := []map[string]interface{}{} + s := utils.EscapeNewLines(llmresult) + json.Unmarshal([]byte(s), &ss) + log.Debug().Msgf("Function return: %s %+v", s, ss) + + for _, s := range ss { + func_name := s["function"] + args := s["arguments"] + d, _ := json.Marshal(args) + results = append(results, funcCallResults{name: func_name.(string), arguments: string(d)}) + } + } else { + // As we have to change the result before processing, we can't stream the answer token-by-token (yet?) 
+ ss := map[string]interface{}{} + // This prevents newlines from breaking JSON parsing for clients + s := utils.EscapeNewLines(llmresult) + json.Unmarshal([]byte(s), &ss) + log.Debug().Msgf("Function return: %s %+v", s, ss) + + // The grammar defines the function name as "function", while OpenAI returns "name" + func_name := ss["function"] + // Similarly, while here arguments is a map[string]interface{}, OpenAI actually wants a stringified object + args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) + d, _ := json.Marshal(args) + + results = append(results, funcCallResults{name: func_name.(string), arguments: string(d)}) + } + + return results } diff --git a/api/openai/completion.go b/api/openai/completion.go index b098451da199..af56625e324b 100644 --- a/api/openai/completion.go +++ b/api/openai/completion.go @@ -8,10 +8,10 @@ import ( "fmt" "time" - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/core/backend" + config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/pkg/grammar" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" diff --git a/api/openai/edit.go b/api/openai/edit.go index 16679ae51fad..56b17920d27a 100644 --- a/api/openai/edit.go +++ b/api/openai/edit.go @@ -5,10 +5,10 @@ import ( "fmt" "time" - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/core/backend" + config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/core/schema" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/google/uuid" diff --git a/api/openai/embeddings.go b/api/openai/embeddings.go index 44feb373d542..198493e1805b 100644 --- a/api/openai/embeddings.go +++ b/api/openai/embeddings.go @@ -5,12 +5,12 @@ import ( "fmt" "time" - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/core/backend" + config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/google/uuid" - "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/core/options" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" ) diff --git a/api/openai/files.go b/api/openai/files.go new file mode 100644 index 000000000000..140b41519407 --- /dev/null +++ b/api/openai/files.go @@ -0,0 +1,218 @@ +package openai + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "time" + + config "github.com/go-skynet/LocalAI/core/config" + 
"github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +var uploadedFiles []File + +const uploadedFilesFile = "uploadedFiles.json" + +// File represents the structure of a file object from the OpenAI API. +type File struct { + ID string `json:"id"` // Unique identifier for the file + Object string `json:"object"` // Type of the object (e.g., "file") + Bytes int `json:"bytes"` // Size of the file in bytes + CreatedAt time.Time `json:"created_at"` // The time at which the file was created + Filename string `json:"filename"` // The name of the file + Purpose string `json:"purpose"` // The purpose of the file (e.g., "fine-tune", "classifications", etc.) +} + +func saveUploadConfig(uploadDir string) { + file, err := json.MarshalIndent(uploadedFiles, "", " ") + if err != nil { + log.Error().Msgf("Failed to JSON marshal the uploadedFiles: %s", err) + } + + err = os.WriteFile(filepath.Join(uploadDir, uploadedFilesFile), file, 0644) + if err != nil { + log.Error().Msgf("Failed to save uploadedFiles to file: %s", err) + } +} + +func LoadUploadConfig(uploadPath string) { + uploadFilePath := filepath.Join(uploadPath, uploadedFilesFile) + + _, err := os.Stat(uploadFilePath) + if os.IsNotExist(err) { + log.Debug().Msgf("No uploadedFiles file found at %s", uploadFilePath) + return + } + + file, err := os.ReadFile(uploadFilePath) + if err != nil { + log.Error().Msgf("Failed to read file: %s", err) + } else { + err = json.Unmarshal(file, &uploadedFiles) + if err != nil { + log.Error().Msgf("Failed to JSON unmarshal the file into uploadedFiles: %s", err) + } + } +} + +// UploadFilesEndpoint https://platform.openai.com/docs/api-reference/files/create +func UploadFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + file, err := c.FormFile("file") + if err != nil { + return err + } + + // Check the file size + if file.Size > int64(o.UploadLimitMB*1024*1024) { + return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("File size %d exceeds upload limit %d", file.Size, o.UploadLimitMB)) + } + + purpose := c.FormValue("purpose", "") //TODO put in purpose dirs + if purpose == "" { + return c.Status(fiber.StatusBadRequest).SendString("Purpose is not defined") + } + + // Sanitize the filename to prevent directory traversal + filename := utils.SanitizeFileName(file.Filename) + + savePath := filepath.Join(o.UploadDir, filename) + + // Check if file already exists + if _, err := os.Stat(savePath); !os.IsNotExist(err) { + return c.Status(fiber.StatusBadRequest).SendString("File already exists") + } + + err = c.SaveFile(file, savePath) + if err != nil { + return c.Status(fiber.StatusInternalServerError).SendString("Failed to save file: " + err.Error()) + } + + f := File{ + ID: fmt.Sprintf("file-%d", time.Now().Unix()), + Object: "file", + Bytes: int(file.Size), + CreatedAt: time.Now(), + Filename: file.Filename, + Purpose: purpose, + } + + uploadedFiles = append(uploadedFiles, f) + saveUploadConfig(o.UploadDir) + return c.Status(fiber.StatusOK).JSON(f) + } +} + +// ListFilesEndpoint https://platform.openai.com/docs/api-reference/files/list +func ListFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + type ListFiles struct { + Data []File + Object string + } + + return func(c *fiber.Ctx) error { + var listFiles ListFiles + + purpose := c.Query("purpose") + if 
purpose == "" { + listFiles.Data = uploadedFiles + } else { + for _, f := range uploadedFiles { + if purpose == f.Purpose { + listFiles.Data = append(listFiles.Data, f) + } + } + } + listFiles.Object = "list" + return c.Status(fiber.StatusOK).JSON(listFiles) + } +} + +func getFileFromRequest(c *fiber.Ctx) (*File, error) { + id := c.Params("file_id") + if id == "" { + return nil, fmt.Errorf("file_id parameter is required") + } + + for _, f := range uploadedFiles { + if id == f.ID { + return &f, nil + } + } + + return nil, fmt.Errorf("unable to find file id %s", id) +} + +// GetFilesEndpoint https://platform.openai.com/docs/api-reference/files/retrieve +func GetFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + file, err := getFileFromRequest(c) + if err != nil { + return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) + } + + return c.JSON(file) + } +} + +// DeleteFilesEndpoint https://platform.openai.com/docs/api-reference/files/delete +func DeleteFilesEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + type DeleteStatus struct { + Id string + Object string + Deleted bool + } + + return func(c *fiber.Ctx) error { + file, err := getFileFromRequest(c) + if err != nil { + return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) + } + + err = os.Remove(filepath.Join(o.UploadDir, file.Filename)) + if err != nil { + // If the file doesn't exist then we should just continue to remove it + if !errors.Is(err, os.ErrNotExist) { + return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Unable to delete file: %s, %v", file.Filename, err)) + } + } + + // Remove upload from list + for i, f := range uploadedFiles { + if f.ID == file.ID { + uploadedFiles = append(uploadedFiles[:i], uploadedFiles[i+1:]...) 
+ break + } + } + + saveUploadConfig(o.UploadDir) + return c.JSON(DeleteStatus{ + Id: file.ID, + Object: "file", + Deleted: true, + }) + } +} + +// GetFilesContentsEndpoint https://platform.openai.com/docs/api-reference/files/retrieve-contents +func GetFilesContentsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + file, err := getFileFromRequest(c) + if err != nil { + return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) + } + + fileContents, err := os.ReadFile(filepath.Join(o.UploadDir, file.Filename)) + if err != nil { + return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) + } + + return c.Send(fileContents) + } +} diff --git a/api/openai/files_test.go b/api/openai/files_test.go new file mode 100644 index 000000000000..535cde8ba564 --- /dev/null +++ b/api/openai/files_test.go @@ -0,0 +1,287 @@ +package openai + +import ( + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + + config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/options" + utils2 "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/gofiber/fiber/v2" + "github.com/stretchr/testify/assert" + + "testing" +) + +type ListFiles struct { + Data []File + Object string +} + +func startUpApp() (app *fiber.App, option *options.Option, loader *config.ConfigLoader) { + // Preparing the mocked objects + loader = &config.ConfigLoader{} + + option = &options.Option{ + UploadLimitMB: 10, + UploadDir: "test_dir", + } + + _ = os.RemoveAll(option.UploadDir) + + app = fiber.New(fiber.Config{ + BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB. + }) + + // Create a Test Server + app.Post("/files", UploadFilesEndpoint(loader, option)) + app.Get("/files", ListFilesEndpoint(loader, option)) + app.Get("/files/:file_id", GetFilesEndpoint(loader, option)) + app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option)) + app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option)) + + return +} + +func TestUploadFileExceedSizeLimit(t *testing.T) { + // Preparing the mocked objects + loader := &config.ConfigLoader{} + + option := &options.Option{ + UploadLimitMB: 10, + UploadDir: "test_dir", + } + + _ = os.RemoveAll(option.UploadDir) + + app := fiber.New(fiber.Config{ + BodyLimit: 20 * 1024 * 1024, // sets the limit to 20MB. 
+ }) + + // Create a Test Server + app.Post("/files", UploadFilesEndpoint(loader, option)) + app.Get("/files", ListFilesEndpoint(loader, option)) + app.Get("/files/:file_id", GetFilesEndpoint(loader, option)) + app.Delete("/files/:file_id", DeleteFilesEndpoint(loader, option)) + app.Get("/files/:file_id/content", GetFilesContentsEndpoint(loader, option)) + + t.Run("UploadFilesEndpoint file size exceeds limit", func(t *testing.T) { + resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 11, option) + assert.NoError(t, err) + + assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode) + assert.Contains(t, bodyToString(resp, t), "exceeds upload limit") + }) + t.Run("UploadFilesEndpoint purpose not defined", func(t *testing.T) { + resp, _ := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "", 5, option) + + assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode) + assert.Contains(t, bodyToString(resp, t), "Purpose is not defined") + }) + t.Run("UploadFilesEndpoint file already exists", func(t *testing.T) { + f1 := CallFilesUploadEndpointWithCleanup(t, app, "foo.txt", "file", "fine-tune", 5, option) + + resp, err := CallFilesUploadEndpoint(t, app, "foo.txt", "file", "fine-tune", 5, option) + fmt.Println(f1) + fmt.Printf("Error: %v", err) + + assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode) + assert.Contains(t, bodyToString(resp, t), "File already exists") + }) + t.Run("UploadFilesEndpoint file uploaded successfully", func(t *testing.T) { + file := CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option) + + // Check if file exists in the disk + filePath := filepath.Join(option.UploadDir, utils2.SanitizeFileName("test.txt")) + _, err := os.Stat(filePath) + + assert.False(t, os.IsNotExist(err)) + assert.Equal(t, file.Bytes, 5242880) + assert.NotEmpty(t, file.CreatedAt) + assert.Equal(t, file.Filename, "test.txt") + assert.Equal(t, file.Purpose, "fine-tune") + }) + t.Run("ListFilesEndpoint without purpose parameter", func(t *testing.T) { + resp, err := CallListFilesEndpoint(t, app, "") + assert.NoError(t, err) + + assert.Equal(t, 200, resp.StatusCode) + + listFiles := responseToListFile(t, resp) + if len(listFiles.Data) != len(uploadedFiles) { + t.Errorf("Expected %v files, got %v files", len(uploadedFiles), len(listFiles.Data)) + } + }) + t.Run("ListFilesEndpoint with valid purpose parameter", func(t *testing.T) { + _ = CallFilesUploadEndpointWithCleanup(t, app, "test.txt", "file", "fine-tune", 5, option) + + resp, err := CallListFilesEndpoint(t, app, "fine-tune") + assert.NoError(t, err) + + listFiles := responseToListFile(t, resp) + if len(listFiles.Data) != 1 { + t.Errorf("Expected 1 file, got %v files", len(listFiles.Data)) + } + }) + t.Run("ListFilesEndpoint with invalid query parameter", func(t *testing.T) { + resp, err := CallListFilesEndpoint(t, app, "not-so-fine-tune") + assert.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + + listFiles := responseToListFile(t, resp) + + if len(listFiles.Data) != 0 { + t.Errorf("Expected 0 files, got %v files", len(listFiles.Data)) + } + }) + t.Run("GetFilesContentsEndpoint get file content", func(t *testing.T) { + req := httptest.NewRequest("GET", "/files", nil) + resp, _ := app.Test(req) + assert.Equal(t, 200, resp.StatusCode) + + var listFiles ListFiles + if err := json.Unmarshal(bodyToByteArray(resp, t), &listFiles); err != nil { + t.Errorf("Failed to decode response: %v", err) + return + } + + if len(listFiles.Data) != 0 { + t.Errorf("Expected 0 files, got %v files", 
len(listFiles.Data)) + } + }) +} + +func CallListFilesEndpoint(t *testing.T, app *fiber.App, purpose string) (*http.Response, error) { + var target string + if purpose != "" { + target = fmt.Sprintf("/files?purpose=%s", purpose) + } else { + target = "/files" + } + req := httptest.NewRequest("GET", target, nil) + return app.Test(req) +} + +func CallFilesContentEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) { + request := httptest.NewRequest("GET", "/files?file_id="+fileId, nil) + return app.Test(request) +} + +func CallFilesUploadEndpoint(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, o *options.Option) (*http.Response, error) { + // Create a file that exceeds the limit + file := createTestFile(t, fileName, fileSize, o) + + // Creating a new HTTP Request + body, writer := newMultipartFile(file.Name(), tag, purpose) + + req := httptest.NewRequest(http.MethodPost, "/files", body) + req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType()) + return app.Test(req) +} + +func CallFilesUploadEndpointWithCleanup(t *testing.T, app *fiber.App, fileName, tag, purpose string, fileSize int, o *options.Option) File { + // Create a file that exceeds the limit + file := createTestFile(t, fileName, fileSize, o) + + // Creating a new HTTP Request + body, writer := newMultipartFile(file.Name(), tag, purpose) + + req := httptest.NewRequest(http.MethodPost, "/files", body) + req.Header.Set(fiber.HeaderContentType, writer.FormDataContentType()) + resp, err := app.Test(req) + assert.NoError(t, err) + f := responseToFile(t, resp) + + id := f.ID + t.Cleanup(func() { + _, err := CallFilesDeleteEndpoint(t, app, id) + assert.NoError(t, err) + }) + + return f + +} + +func CallFilesDeleteEndpoint(t *testing.T, app *fiber.App, fileId string) (*http.Response, error) { + target := fmt.Sprintf("/files/%s", fileId) + req := httptest.NewRequest(http.MethodDelete, target, nil) + return app.Test(req) +} + +// Helper to create multi-part file +func newMultipartFile(filePath, tag, purpose string) (*strings.Reader, *multipart.Writer) { + body := new(strings.Builder) + writer := multipart.NewWriter(body) + file, _ := os.Open(filePath) + defer file.Close() + part, _ := writer.CreateFormFile(tag, filepath.Base(filePath)) + io.Copy(part, file) + + if purpose != "" { + _ = writer.WriteField("purpose", purpose) + } + + writer.Close() + return strings.NewReader(body.String()), writer +} + +// Helper to create test files +func createTestFile(t *testing.T, name string, sizeMB int, option *options.Option) *os.File { + err := os.MkdirAll(option.UploadDir, 0755) + if err != nil { + + t.Fatalf("Error MKDIR: %v", err) + } + + file, _ := os.Create(name) + file.WriteString(strings.Repeat("a", sizeMB*1024*1024)) // sizeMB MB File + + t.Cleanup(func() { + os.Remove(name) + os.RemoveAll(option.UploadDir) + }) + return file +} + +func bodyToString(resp *http.Response, t *testing.T) string { + return string(bodyToByteArray(resp, t)) +} + +func bodyToByteArray(resp *http.Response, t *testing.T) []byte { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + return bodyBytes +} + +func responseToFile(t *testing.T, resp *http.Response) File { + var file File + responseToString := bodyToString(resp, t) + + err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&file) + if err != nil { + t.Errorf("Failed to decode response: %s", err) + } + + return file +} + +func responseToListFile(t *testing.T, resp *http.Response) ListFiles { + var listFiles 
ListFiles + responseToString := bodyToString(resp, t) + + err := json.NewDecoder(strings.NewReader(responseToString)).Decode(&listFiles) + if err != nil { + fmt.Printf("Failed to decode response: %s", err) + } + + return listFiles +} diff --git a/api/openai/image.go b/api/openai/image.go index 07f028f013d1..2da6883eb193 100644 --- a/api/openai/image.go +++ b/api/openai/image.go @@ -13,12 +13,12 @@ import ( "strings" "time" - "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/core/schema" "github.com/google/uuid" - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/core/backend" + config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/options" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" diff --git a/api/openai/inference.go b/api/openai/inference.go index 816c960c3798..184688b27252 100644 --- a/api/openai/inference.go +++ b/api/openai/inference.go @@ -1,10 +1,10 @@ package openai import ( - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/core/backend" + config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/core/schema" model "github.com/go-skynet/LocalAI/pkg/model" ) diff --git a/api/openai/list.go b/api/openai/list.go index 8bc5bbe22bee..614d5c80e8b1 100644 --- a/api/openai/list.go +++ b/api/openai/list.go @@ -3,8 +3,8 @@ package openai import ( "regexp" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/schema" + config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" ) diff --git a/api/openai/request.go b/api/openai/request.go index 382a930e1c79..83c41d975f2b 100644 --- a/api/openai/request.go +++ b/api/openai/request.go @@ -9,10 +9,11 @@ import ( "net/http" "strings" - config "github.com/go-skynet/LocalAI/api/config" fiberContext "github.com/go-skynet/LocalAI/api/ctx" - options "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/api/schema" + config "github.com/go-skynet/LocalAI/core/config" + options "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/core/schema" + "github.com/go-skynet/LocalAI/pkg/grammar" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" @@ -136,6 +137,20 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) { } } + if len(input.Tools) > 0 { + for _, tool := range input.Tools { + input.Functions = append(input.Functions, tool.Function) + } + } + + if input.ToolsChoice != nil { + var toolChoice 
grammar.Tool + json.Unmarshal([]byte(input.ToolsChoice.(string)), &toolChoice) + input.FunctionCall = map[string]interface{}{ + "name": toolChoice.Function.Name, + } + } + // Decode each request's message content index := 0 for i, m := range input.Messages { diff --git a/api/openai/transcription.go b/api/openai/transcription.go index 668a20698176..c3fd7d5c83cc 100644 --- a/api/openai/transcription.go +++ b/api/openai/transcription.go @@ -8,9 +8,9 @@ import ( "path" "path/filepath" - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/core/backend" + config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/options" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" diff --git a/backend/cpp/llama/CMakeLists.txt b/backend/cpp/llama/CMakeLists.txt index 8299705a36a0..031e49643fbb 100644 --- a/backend/cpp/llama/CMakeLists.txt +++ b/backend/cpp/llama/CMakeLists.txt @@ -2,16 +2,20 @@ ## XXX: In some versions of CMake clip wasn't being built before llama. ## This is an hack for now, but it should be fixed in the future. set(TARGET myclip) -add_library(${TARGET} clip.cpp clip.h) +add_library(${TARGET} clip.cpp clip.h llava.cpp llava.h) install(TARGETS ${TARGET} LIBRARY) -target_link_libraries(${TARGET} PRIVATE common ggml ${CMAKE_THREAD_LIBS_INIT}) +target_include_directories(myclip PUBLIC .) +target_include_directories(myclip PUBLIC ../..) +target_include_directories(myclip PUBLIC ../../common) +target_link_libraries(${TARGET} PRIVATE common ggml llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) if (NOT MSVC) target_compile_options(${TARGET} PRIVATE -Wno-cast-qual) # stb_image.h endif() +# END CLIP hack + set(TARGET grpc-server) -# END CLIP hack set(CMAKE_CXX_STANDARD 17) cmake_minimum_required(VERSION 3.15) set(TARGET grpc-server) diff --git a/backend/cpp/llama/Makefile b/backend/cpp/llama/Makefile index b050b62003a5..d6d8ae9039c2 100644 --- a/backend/cpp/llama/Makefile +++ b/backend/cpp/llama/Makefile @@ -45,6 +45,9 @@ llama.cpp/examples/grpc-server: ## XXX: In some versions of CMake clip wasn't being built before llama. ## This is an hack for now, but it should be fixed in the future. 
cp -rfv llama.cpp/examples/llava/clip.h llama.cpp/examples/grpc-server/clip.h + cp -rfv llama.cpp/examples/llava/llava.cpp llama.cpp/examples/grpc-server/llava.cpp + echo '#include "llama.h"' > llama.cpp/examples/grpc-server/llava.h + cat llama.cpp/examples/llava/llava.h >> llama.cpp/examples/grpc-server/llava.h cp -rfv llama.cpp/examples/llava/clip.cpp llama.cpp/examples/grpc-server/clip.cpp rebuild: diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp index 954e472a786b..0066c16d533e 100644 --- a/backend/cpp/llama/grpc-server.cpp +++ b/backend/cpp/llama/grpc-server.cpp @@ -11,7 +11,8 @@ #include #include #include -#include "../llava/clip.h" +#include "clip.h" +#include "llava.h" #include "stb_image.h" #include "common.h" #include "json.hpp" @@ -32,6 +33,7 @@ #include #include #include +#include using grpc::Server; using grpc::ServerBuilder; @@ -51,9 +53,11 @@ struct server_params std::string hostname = "127.0.0.1"; std::vector api_keys; std::string public_path = "examples/server/public"; + std::string chat_template = ""; int32_t port = 8080; int32_t read_timeout = 600; int32_t write_timeout = 600; + bool slots_endpoint = true; }; bool server_verbose = false; @@ -172,6 +176,7 @@ struct llama_client_slot int32_t n_decoded = 0; int32_t n_remaining = -1; int32_t i_batch = -1; + int32_t n_predict = -1; int32_t num_prompt_tokens = 0; int32_t num_prompt_tokens_processed = 0; @@ -349,6 +354,7 @@ struct llama_server_context // slots / clients std::vector slots; + json default_generation_settings_for_props; llama_server_queue queue_tasks; llama_server_response queue_results; @@ -422,6 +428,7 @@ struct llama_server_context slot.id = i; slot.n_ctx = n_ctx_slot; + slot.n_predict = params.n_predict; LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot); @@ -445,11 +452,10 @@ struct llama_server_context slots.push_back(slot); } - batch = llama_batch_init(n_ctx, 0, params.n_parallel); + default_generation_settings_for_props = get_formated_generation(slots.front()); + default_generation_settings_for_props["seed"] = -1; - // empty system prompt - system_prompt = ""; - system_tokens.clear(); + batch = llama_batch_init(n_ctx, 0, params.n_parallel); } std::vector tokenize(const json & json_prompt, bool add_bos) const @@ -526,28 +532,40 @@ struct llama_server_context bool launch_slot_with_data(llama_client_slot* &slot, json data) { slot_params default_params; llama_sampling_params default_sparams; - - slot->params.stream = json_value(data, "stream", false); - slot->params.cache_prompt = json_value(data, "cache_prompt", false); - slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict); - slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k); - slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p); - slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p); - slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); - slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); - slot->sparams.temp = json_value(data, "temperature", default_sparams.temp); - slot->sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); - slot->sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); - slot->sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); - slot->sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); - 
slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); - slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); - slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); - slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); - slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep); - slot->params.seed = json_value(data, "seed", default_params.seed); - slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar); - slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); + + slot->params.stream = json_value(data, "stream", false); + slot->params.cache_prompt = json_value(data, "cache_prompt", false); + slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict); + slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k); + slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p); + slot->sparams.min_p = json_value(data, "min_p", default_sparams.min_p); + slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); + slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); + slot->sparams.temp = json_value(data, "temperature", default_sparams.temp); + slot->sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); + slot->sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); + slot->sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); + slot->sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); + slot->sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); + slot->sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); + slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); + slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); + slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); + slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); + slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep); + slot->params.seed = json_value(data, "seed", default_params.seed); + slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar); + slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); + slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); + + if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) { + // Might be better to reject the request with a 400 ? 
+ LOG_WARNING("Max tokens to predict exceeds server configuration", { + {"params.n_predict", slot->params.n_predict}, + {"slot.n_predict", slot->n_predict}, + }); + slot->params.n_predict = slot->n_predict; + } // infill if (data.count("input_prefix") != 0) @@ -626,18 +644,36 @@ struct llama_server_context const int n_vocab = llama_n_vocab(model); for (const auto &el : *logit_bias) { - if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) + if (el.is_array() && el.size() == 2) { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) + float bias; + if (el[1].is_number()) + { + bias = el[1].get(); + } + else if (el[1].is_boolean() && !el[1].get()) + { + bias = -INFINITY; + } + else { - if (el[1].is_number()) + continue; + } + + if (el[0].is_number_integer()) + { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { - slot->sparams.logit_bias[tok] = el[1].get(); + slot->sparams.logit_bias[tok] = bias; } - else if (el[1].is_boolean() && !el[1].get()) + } + else if (el[0].is_string()) + { + auto toks = llama_tokenize(model, el[0].get(), false); + for (auto tok : toks) { - slot->sparams.logit_bias[tok] = -INFINITY; + slot->sparams.logit_bias[tok] = bias; } } } @@ -658,6 +694,24 @@ struct llama_server_context } } + const auto &samplers_sequence = data.find("samplers"); + if (samplers_sequence != data.end() && samplers_sequence->is_array()) + { + std::vector sampler_names; + for (const auto &sampler_name : *samplers_sequence) + { + if (sampler_name.is_string()) + { + sampler_names.emplace_back(sampler_name); + } + } + slot->sparams.samplers_sequence = sampler_types_from_names(sampler_names, false); + } + else + { + slot->sparams.samplers_sequence = default_sparams.samplers_sequence; + } + if (multimodal) { const auto &images_data = data.find("image_data"); @@ -747,27 +801,30 @@ struct llama_server_context } void update_system_prompt() { - system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token); + kv_cache_clear(); + system_tokens.clear(); - llama_batch_clear(batch); + if (!system_prompt.empty()) { + system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token); - kv_cache_clear(); + llama_batch_clear(batch); - for (int i = 0; i < (int) system_tokens.size(); ++i) - { - llama_batch_add(batch, system_tokens[i], i, { 0 }, false); - } + for (int i = 0; i < (int)system_tokens.size(); ++i) + { + llama_batch_add(batch, system_tokens[i], i, { 0 }, false); + } - if (llama_decode(ctx, batch) != 0) - { - LOG_TEE("%s: llama_decode() failed\n", __func__); - return; - } + if (llama_decode(ctx, batch) != 0) + { + LOG_TEE("%s: llama_decode() failed\n", __func__); + return; + } - // assign the system KV cache to all parallel sequences - for (int32_t i = 1; i < params.n_parallel; ++i) - { - llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size()); + // assign the system KV cache to all parallel sequences + for (int32_t i = 1; i < params.n_parallel; ++i) + { + llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size()); + } } LOG_TEE("system prompt updated\n"); @@ -789,10 +846,8 @@ struct llama_server_context name_user = sys_props.value("anti_prompt", ""); name_assistant = sys_props.value("assistant_name", ""); - if (slots.size() > 0) - { - notify_system_prompt_changed(); - } + + notify_system_prompt_changed(); } static size_t find_stopping_strings(const std::string &text, const size_t last_token_size, @@ -950,28 +1005,12 @@ struct llama_server_context { continue; } - clip_image_f32 * img_res = clip_image_f32_init(); - if (!clip_image_preprocess(clp_ctx, 
img.img_data, img_res, /*pad2square =*/ true)) - { + + if (!llava_image_embed_make_with_clip_img(clp_ctx, params.n_threads, img.img_data, &img.image_embedding, &img.image_tokens)) { LOG_TEE("Error processing the given image"); - clip_free(clp_ctx); return false; } - img.image_tokens = clip_n_patches(clp_ctx); - img.image_embedding = (float *)malloc(clip_embd_nbytes(clp_ctx)); - if (!img.image_embedding) - { - LOG_TEE("Unable to allocate memory for image embeddings\n"); - clip_free(clp_ctx); - return false; - } - LOG_TEE("slot %i - encoding image [id: %i]\n", slot.id, img.id); - if (!clip_image_encode(clp_ctx, params.n_threads, img_res, img.image_embedding)) - { - LOG_TEE("Unable to encode image\n"); - return false; - } - clip_image_f32_free(img_res); + img.request_encode_image = false; } @@ -990,21 +1029,25 @@ struct llama_server_context queue_results.send(res); } - json get_model_props() - { - return get_formated_generation(slots[0]); - } - json get_formated_generation(llama_client_slot &slot) { const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); + std::vector samplers_sequence; + for (const auto &sampler_type : slot.sparams.samplers_sequence) + { + samplers_sequence.emplace_back(sampler_type_to_name_string(sampler_type)); + } + return json { {"n_ctx", slot.n_ctx}, + {"n_predict", slot.n_predict}, {"model", params.model_alias}, {"seed", slot.params.seed}, {"temperature", slot.sparams.temp}, + {"dynatemp_range", slot.sparams.dynatemp_range}, + {"dynatemp_exponent", slot.sparams.dynatemp_exponent}, {"top_k", slot.sparams.top_k}, {"top_p", slot.sparams.top_p}, {"min_p", slot.sparams.min_p}, @@ -1027,7 +1070,9 @@ struct llama_server_context {"stream", slot.params.stream}, {"logit_bias", slot.sparams.logit_bias}, {"n_probs", slot.sparams.n_probs}, + {"min_keep", slot.sparams.min_keep}, {"grammar", slot.sparams.grammar}, + {"samplers", samplers_sequence} }; } @@ -1166,13 +1211,30 @@ struct llama_server_context task.multitask_id = multitask_id; // when a completion task's prompt array is not a singleton, we split it into multiple requests - if (task.data.count("prompt") && task.data.at("prompt").size() > 1) - { - split_multiprompt_task(task_id, task); - } - // otherwise, it's a single-prompt task, we actually queue it - queue_tasks.post(task); + // if there's numbers in the prompt array it will be treated as an array of tokens + if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) { + bool numbers = false; + for (const auto& e : task.data.at("prompt")) { + if (e.is_number()) { + numbers = true; + break; + } + } + + // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers, + // it will completely stall the server. I don't know where the bug for this is. + // + // if there are numbers, it needs to be treated like a single prompt, + // queue_tasks handles a mix of strings and numbers just fine. 
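The request_completion change above (its branch continues in the next hunk lines) splits a prompt array into one sub-request per element only when every element is a string; if any element is a number, the array is assumed to be pre-tokenized and is posted as a single task. A hedged Go sketch of just that detection step — the function name and the float64 type switch are illustrative, not code from this PR:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// promptHasNumbers reports whether any element of a decoded "prompt" array is
// a number, i.e. the array should be treated as one tokenized prompt instead
// of being split into one sub-request per element.
func promptHasNumbers(prompt []interface{}) bool {
	for _, e := range prompt {
		if _, ok := e.(float64); ok { // encoding/json decodes JSON numbers as float64
			return true
		}
	}
	return false
}

func main() {
	var strings, tokens []interface{}
	_ = json.Unmarshal([]byte(`["first prompt", "second prompt"]`), &strings)
	_ = json.Unmarshal([]byte(`[15043, 2787, "mixed"]`), &tokens)

	fmt.Println(promptHasNumbers(strings)) // false -> split into sub-tasks
	fmt.Println(promptHasNumbers(tokens))  // true  -> keep as a single task
}
```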
+ if (numbers) { + queue_tasks.post(task); + } else { + split_multiprompt_task(task_id, task); + } + } else { + queue_tasks.post(task); + } } // for multiple images processing @@ -1254,7 +1316,10 @@ struct llama_server_context void split_multiprompt_task(int multitask_id, task_server& multiprompt_task) { int prompt_count = multiprompt_task.data.at("prompt").size(); - assert(prompt_count > 1); + if (prompt_count <= 1) { + send_error(multiprompt_task, "error while handling multiple prompts"); + return; + } // generate all the ID for subtask std::vector subtask_ids(prompt_count); @@ -1566,10 +1631,6 @@ struct llama_server_context LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed); } - LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past); - - llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1); - slot.cache_tokens = prompt_tokens; if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0) @@ -1583,6 +1644,10 @@ struct llama_server_context } } + LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past); + + llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1); + LOG_VERBOSE("prompt ingested", { {"n_past", slot.n_past}, {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)}, @@ -1819,6 +1884,9 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con } } +std::function shutdown_handler; +inline void signal_handler(int signal) { shutdown_handler(signal); } + ///////////////////////////////// //////////////////////////////// //////// LOCALAI code starts below here @@ -2089,7 +2157,8 @@ class BackendServiceImpl final : public backend::Backend::Service { gpt_params params; params_parse(request, params); - llama_backend_init(params.numa); + llama_backend_init(); + llama_numa_init(params.numa); // load the model if (!llama.load_model(params)) diff --git a/backend/go/transcribe/transcript.go b/backend/go/transcribe/transcript.go index ebd43eca6b84..dc331caea988 100644 --- a/backend/go/transcribe/transcript.go +++ b/backend/go/transcribe/transcript.go @@ -8,7 +8,7 @@ import ( "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" "github.com/go-audio/wav" - "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/core/schema" ) func sh(c string) (string, error) { diff --git a/backend/go/transcribe/whisper.go b/backend/go/transcribe/whisper.go index a033afb0cdbc..ac93be01195b 100644 --- a/backend/go/transcribe/whisper.go +++ b/backend/go/transcribe/whisper.go @@ -4,7 +4,7 @@ package main // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) import ( "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" - "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" ) diff --git a/backend/python/common-env/transformers/Makefile b/backend/python/common-env/transformers/Makefile index 47a5ba25db91..1cd71ab177d3 100644 --- a/backend/python/common-env/transformers/Makefile +++ b/backend/python/common-env/transformers/Makefile @@ -4,6 +4,10 @@ ifeq ($(BUILD_TYPE), cublas) CONDA_ENV_PATH = 
"transformers-nvidia.yml" endif +ifeq ($(BUILD_TYPE), hipblas) + CONDA_ENV_PATH = "transformers-rocm.yml" +endif + .PHONY: transformers transformers: @echo "Installing $(CONDA_ENV_PATH)..." diff --git a/backend/python/common-env/transformers/transformers-nvidia.yml b/backend/python/common-env/transformers/transformers-nvidia.yml index 621335590cbd..d5fe07b4d1d3 100644 --- a/backend/python/common-env/transformers/transformers-nvidia.yml +++ b/backend/python/common-env/transformers/transformers-nvidia.yml @@ -33,6 +33,7 @@ dependencies: - boto3==1.28.61 - botocore==1.31.61 - certifi==2023.7.22 + - TTS==0.22.0 - charset-normalizer==3.3.0 - datasets==2.14.5 - sentence-transformers==2.2.2 diff --git a/backend/python/common-env/transformers/transformers-rocm.yml b/backend/python/common-env/transformers/transformers-rocm.yml new file mode 100644 index 000000000000..1f5d223623c3 --- /dev/null +++ b/backend/python/common-env/transformers/transformers-rocm.yml @@ -0,0 +1,109 @@ +name: transformers +channels: + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2023.08.22=h06a4308_0 + - ld_impl_linux-64=2.38=h1181459_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - libuuid=1.41.5=h5eee18b_0 + - ncurses=6.4=h6a678d5_0 + - openssl=3.0.11=h7f8727e_2 + - pip=23.2.1=py311h06a4308_0 + - python=3.11.5=h955ad1f_0 + - readline=8.2=h5eee18b_0 + - setuptools=68.0.0=py311h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - tk=8.6.12=h1ccaba5_0 + - wheel=0.41.2=py311h06a4308_0 + - xz=5.4.2=h5eee18b_0 + - zlib=1.2.13=h5eee18b_0 + - pip: + - --pre + - --extra-index-url https://download.pytorch.org/whl/nightly/ + - accelerate==0.23.0 + - aiohttp==3.8.5 + - aiosignal==1.3.1 + - async-timeout==4.0.3 + - attrs==23.1.0 + - bark==0.1.5 + - boto3==1.28.61 + - botocore==1.31.61 + - certifi==2023.7.22 + - TTS==0.22.0 + - charset-normalizer==3.3.0 + - datasets==2.14.5 + - sentence-transformers==2.2.2 + - sentencepiece==0.1.99 + - dill==0.3.7 + - einops==0.7.0 + - encodec==0.1.1 + - filelock==3.12.4 + - frozenlist==1.4.0 + - fsspec==2023.6.0 + - funcy==2.0 + - grpcio==1.59.0 + - huggingface-hub + - idna==3.4 + - jinja2==3.1.2 + - jmespath==1.0.1 + - markupsafe==2.1.3 + - mpmath==1.3.0 + - multidict==6.0.4 + - multiprocess==0.70.15 + - networkx + - numpy==1.26.0 + - packaging==23.2 + - pandas + - peft==0.5.0 + - protobuf==4.24.4 + - psutil==5.9.5 + - pyarrow==13.0.0 + - python-dateutil==2.8.2 + - pytz==2023.3.post1 + - pyyaml==6.0.1 + - regex==2023.10.3 + - requests==2.31.0 + - rouge==1.0.1 + - s3transfer==0.7.0 + - safetensors==0.3.3 + - scipy==1.11.3 + - six==1.16.0 + - sympy==1.12 + - tokenizers + - torch + - torchaudio + - tqdm==4.66.1 + - triton==2.1.0 + - typing-extensions==4.8.0 + - tzdata==2023.3 + - auto-gptq==0.6.0 + - urllib3==1.26.17 + - xxhash==3.4.1 + - yarl==1.9.2 + - soundfile + - langid + - wget + - unidecode + - pyopenjtalk-prebuilt + - pypinyin + - inflect + - cn2an + - jieba + - eng_to_ipa + - openai-whisper + - matplotlib + - gradio==3.41.2 + - nltk + - sudachipy + - sudachidict_core + - vocos + - vllm==0.2.7 + - transformers>=4.36.0 # Required for Mixtral. 
+ - xformers==0.0.23.post1 +prefix: /opt/conda/envs/transformers diff --git a/backend/python/diffusers/Makefile b/backend/python/diffusers/Makefile index 4ec03c710359..70a62b60daa9 100644 --- a/backend/python/diffusers/Makefile +++ b/backend/python/diffusers/Makefile @@ -1,4 +1,8 @@ -CONDA_ENV_PATH = "diffusers.yml" +export CONDA_ENV_PATH = "diffusers.yml" + +ifeq ($(BUILD_TYPE), hipblas) +export CONDA_ENV_PATH = "diffusers-rocm.yml" +endif .PHONY: diffusers diffusers: @@ -12,4 +16,4 @@ run: @echo "Diffusers run." test: - bash test.sh \ No newline at end of file + bash test.sh diff --git a/backend/python/diffusers/diffusers-rocm.yml b/backend/python/diffusers/diffusers-rocm.yml new file mode 100644 index 000000000000..f261701dbb36 --- /dev/null +++ b/backend/python/diffusers/diffusers-rocm.yml @@ -0,0 +1,64 @@ +name: diffusers +channels: + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2023.08.22=h06a4308_0 + - ld_impl_linux-64=2.38=h1181459_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - libuuid=1.41.5=h5eee18b_0 + - ncurses=6.4=h6a678d5_0 + - openssl=3.0.11=h7f8727e_2 + - pip=23.2.1=py311h06a4308_0 + - python=3.11.5=h955ad1f_0 + - readline=8.2=h5eee18b_0 + - setuptools=68.0.0=py311h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - tk=8.6.12=h1ccaba5_0 + - tzdata=2023c=h04d1e81_0 + - wheel=0.41.2=py311h06a4308_0 + - xz=5.4.2=h5eee18b_0 + - zlib=1.2.13=h5eee18b_0 + - pip: + - --pre + - --extra-index-url https://download.pytorch.org/whl/nightly/ + - accelerate>=0.11.0 + - certifi==2023.7.22 + - charset-normalizer==3.3.0 + - compel==2.0.2 + - diffusers==0.24.0 + - filelock==3.12.4 + - fsspec==2023.9.2 + - grpcio==1.59.0 + - huggingface-hub>=0.19.4 + - idna==3.4 + - importlib-metadata==6.8.0 + - jinja2==3.1.2 + - markupsafe==2.1.3 + - mpmath==1.3.0 + - networkx==3.1 + - numpy==1.26.0 + - omegaconf + - packaging==23.2 + - pillow==10.0.1 + - protobuf==4.24.4 + - psutil==5.9.5 + - pyparsing==3.1.1 + - pyyaml==6.0.1 + - regex==2023.10.3 + - requests==2.31.0 + - safetensors==0.4.0 + - sympy==1.12 + - tqdm==4.66.1 + - transformers>=4.25.1 + - triton==2.1.0 + - typing-extensions==4.8.0 + - urllib3==2.0.6 + - zipp==3.17.0 + - torch +prefix: /opt/conda/envs/diffusers diff --git a/backend/python/diffusers/diffusers.yml b/backend/python/diffusers/diffusers.yml index a37f41d9c439..b1a7d9f971b2 100644 --- a/backend/python/diffusers/diffusers.yml +++ b/backend/python/diffusers/diffusers.yml @@ -71,4 +71,4 @@ dependencies: - typing-extensions==4.8.0 - urllib3==2.0.6 - zipp==3.17.0 -prefix: /opt/conda/envs/diffusers \ No newline at end of file +prefix: /opt/conda/envs/diffusers diff --git a/api/backend/embeddings.go b/core/backend/embeddings.go similarity index 95% rename from api/backend/embeddings.go rename to core/backend/embeddings.go index 0cf15fea32cf..d8b89e124206 100644 --- a/api/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -3,8 +3,8 @@ package backend import ( "fmt" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" + config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/options" "github.com/go-skynet/LocalAI/pkg/grpc" model "github.com/go-skynet/LocalAI/pkg/model" ) diff --git a/api/backend/image.go b/core/backend/image.go similarity index 94% rename from 
api/backend/image.go rename to core/backend/image.go index 6183269fd3ca..12ea57ceb7dd 100644 --- a/api/backend/image.go +++ b/core/backend/image.go @@ -1,8 +1,8 @@ package backend import ( - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" + config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/options" "github.com/go-skynet/LocalAI/pkg/grpc/proto" model "github.com/go-skynet/LocalAI/pkg/model" ) diff --git a/api/backend/llm.go b/core/backend/llm.go similarity index 97% rename from api/backend/llm.go rename to core/backend/llm.go index 9e202c53c53b..d1081ad65fc1 100644 --- a/api/backend/llm.go +++ b/core/backend/llm.go @@ -8,8 +8,8 @@ import ( "sync" "unicode/utf8" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" + config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/options" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/grpc" model "github.com/go-skynet/LocalAI/pkg/model" diff --git a/api/backend/options.go b/core/backend/options.go similarity index 97% rename from api/backend/options.go rename to core/backend/options.go index 38f560688f79..9710ac175d3e 100644 --- a/api/backend/options.go +++ b/core/backend/options.go @@ -7,8 +7,8 @@ import ( pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" model "github.com/go-skynet/LocalAI/pkg/model" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" + config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/options" ) func modelOpts(c config.Config, o *options.Option, opts []model.Option) []model.Option { diff --git a/api/backend/transcript.go b/core/backend/transcript.go similarity index 86% rename from api/backend/transcript.go rename to core/backend/transcript.go index 77427839992a..1cbaf8201669 100644 --- a/api/backend/transcript.go +++ b/core/backend/transcript.go @@ -4,10 +4,10 @@ import ( "context" "fmt" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/schema" + config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/schema" - "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/core/options" "github.com/go-skynet/LocalAI/pkg/grpc/proto" model "github.com/go-skynet/LocalAI/pkg/model" ) diff --git a/api/backend/tts.go b/core/backend/tts.go similarity index 90% rename from api/backend/tts.go rename to core/backend/tts.go index 6e5ffcc0c1b1..a9d7153f9348 100644 --- a/api/backend/tts.go +++ b/core/backend/tts.go @@ -6,9 +6,8 @@ import ( "os" "path/filepath" - api_config "github.com/go-skynet/LocalAI/api/config" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" + config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/options" "github.com/go-skynet/LocalAI/pkg/grpc/proto" model "github.com/go-skynet/LocalAI/pkg/model" 
"github.com/go-skynet/LocalAI/pkg/utils" @@ -38,7 +37,7 @@ func ModelTTS(backend, text, modelFile string, loader *model.ModelLoader, o *opt grpcOpts := gRPCModelOpts(c) - opts := modelOpts(api_config.Config{}, o, []model.Option{ + opts := modelOpts(config.Config{}, o, []model.Option{ model.WithBackendString(bb), model.WithModel(modelFile), model.WithContext(o.Context), diff --git a/api/config/config.go b/core/config/config.go similarity index 99% rename from api/config/config.go rename to core/config/config.go index 48d1b791fe80..af203ecccf14 100644 --- a/api/config/config.go +++ b/core/config/config.go @@ -1,4 +1,4 @@ -package api_config +package config import ( "errors" @@ -148,6 +148,7 @@ type Functions struct { DisableNoAction bool `yaml:"disable_no_action"` NoActionFunctionName string `yaml:"no_action_function_name"` NoActionDescriptionName string `yaml:"no_action_description_name"` + ParallelCalls bool `yaml:"parallel_calls"` } type TemplateConfig struct { diff --git a/api/config/config_test.go b/core/config/config_test.go similarity index 93% rename from api/config/config_test.go rename to core/config/config_test.go index 4b00d587eff2..d1e92d5cc41e 100644 --- a/api/config/config_test.go +++ b/core/config/config_test.go @@ -1,10 +1,10 @@ -package api_config_test +package config_test import ( "os" - . "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" + . "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/options" "github.com/go-skynet/LocalAI/pkg/model" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" diff --git a/api/config/prediction.go b/core/config/prediction.go similarity index 99% rename from api/config/prediction.go rename to core/config/prediction.go index d2fbb1fa9687..dccb4dfb9f6f 100644 --- a/api/config/prediction.go +++ b/core/config/prediction.go @@ -1,4 +1,4 @@ -package api_config +package config type PredictionOptions struct { diff --git a/api/api.go b/core/http/api.go similarity index 89% rename from api/api.go rename to core/http/api.go index 7ec95f1b63a6..7d228152409e 100644 --- a/api/api.go +++ b/core/http/api.go @@ -1,4 +1,4 @@ -package api +package http import ( "encoding/json" @@ -7,11 +7,11 @@ import ( "os" "strings" - config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/localai" "github.com/go-skynet/LocalAI/api/openai" - "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/api/schema" + config "github.com/go-skynet/LocalAI/core/config" + "github.com/go-skynet/LocalAI/core/options" + "github.com/go-skynet/LocalAI/core/schema" "github.com/go-skynet/LocalAI/internal" "github.com/go-skynet/LocalAI/metrics" "github.com/go-skynet/LocalAI/pkg/assets" @@ -146,7 +146,11 @@ func App(opts ...options.AppOption) (*fiber.App, error) { } // Default middleware config - app.Use(recover.New()) + + if !options.Debug { + app.Use(recover.New()) + } + if options.Metrics != nil { app.Use(metrics.APIMiddleware(options.Metrics)) } @@ -219,8 +223,12 @@ func App(opts ...options.AppOption) (*fiber.App, error) { // Make sure directories exists os.MkdirAll(options.ImageDir, 0755) os.MkdirAll(options.AudioDir, 0755) + os.MkdirAll(options.UploadDir, 0755) os.MkdirAll(options.Loader.ModelPath, 0755) + // 
Load upload json + openai.LoadUploadConfig(options.UploadDir) + modelGalleryService := localai.CreateModelGalleryService(options.Galleries, options.Loader.ModelPath, galleryService) app.Post("/models/apply", auth, modelGalleryService.ApplyModelGalleryEndpoint()) app.Get("/models/available", auth, modelGalleryService.ListModelFromGalleryEndpoint()) @@ -240,6 +248,18 @@ func App(opts ...options.AppOption) (*fiber.App, error) { app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options)) app.Post("/edits", auth, openai.EditEndpoint(cl, options)) + // files + app.Post("/v1/files", auth, openai.UploadFilesEndpoint(cl, options)) + app.Post("/files", auth, openai.UploadFilesEndpoint(cl, options)) + app.Get("/v1/files", auth, openai.ListFilesEndpoint(cl, options)) + app.Get("/files", auth, openai.ListFilesEndpoint(cl, options)) + app.Get("/v1/files/:file_id", auth, openai.GetFilesEndpoint(cl, options)) + app.Get("/files/:file_id", auth, openai.GetFilesEndpoint(cl, options)) + app.Delete("/v1/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options)) + app.Delete("/files/:file_id", auth, openai.DeleteFilesEndpoint(cl, options)) + app.Get("/v1/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options)) + app.Get("/files/:file_id/content", auth, openai.GetFilesContentsEndpoint(cl, options)) + // completion app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options)) app.Post("/completions", auth, openai.CompletionEndpoint(cl, options)) diff --git a/api/api_test.go b/core/http/api_test.go similarity index 99% rename from api/api_test.go rename to core/http/api_test.go index 04d2d6fec02e..9068b393e1a9 100644 --- a/api/api_test.go +++ b/core/http/api_test.go @@ -1,4 +1,4 @@ -package api_test +package http_test import ( "bytes" @@ -13,8 +13,8 @@ import ( "path/filepath" "runtime" - . "github.com/go-skynet/LocalAI/api" - "github.com/go-skynet/LocalAI/api/options" + . 
"github.com/go-skynet/LocalAI/core/http" + "github.com/go-skynet/LocalAI/core/options" "github.com/go-skynet/LocalAI/metrics" "github.com/go-skynet/LocalAI/pkg/downloader" "github.com/go-skynet/LocalAI/pkg/gallery" diff --git a/api/apt_suite_test.go b/core/http/apt_suite_test.go similarity index 90% rename from api/apt_suite_test.go rename to core/http/apt_suite_test.go index e3c15c048b14..0269a97321df 100644 --- a/api/apt_suite_test.go +++ b/core/http/apt_suite_test.go @@ -1,4 +1,4 @@ -package api_test +package http_test import ( "testing" diff --git a/api/options/options.go b/core/options/options.go similarity index 97% rename from api/options/options.go rename to core/options/options.go index 8c066584038b..72aea1a32932 100644 --- a/api/options/options.go +++ b/core/options/options.go @@ -21,6 +21,7 @@ type Option struct { Debug, DisableMessage bool ImageDir string AudioDir string + UploadDir string CORS bool PreloadJSONModels string PreloadModelsFromPath string @@ -249,6 +250,12 @@ func WithImageDir(imageDir string) AppOption { } } +func WithUploadDir(uploadDir string) AppOption { + return func(o *Option) { + o.UploadDir = uploadDir + } +} + func WithApiKeys(apiKeys []string) AppOption { return func(o *Option) { o.ApiKeys = apiKeys diff --git a/api/schema/openai.go b/core/schema/openai.go similarity index 84% rename from api/schema/openai.go rename to core/schema/openai.go index 6355ff63d5e2..23abd7b76c46 100644 --- a/api/schema/openai.go +++ b/core/schema/openai.go @@ -3,7 +3,7 @@ package schema import ( "context" - config "github.com/go-skynet/LocalAI/api/config" + config "github.com/go-skynet/LocalAI/core/config" "github.com/go-skynet/LocalAI/pkg/grammar" ) @@ -68,6 +68,10 @@ type ContentURL struct { type Message struct { // The message role Role string `json:"role,omitempty" yaml:"role"` + + // The message name (used for tools calls) + Name string `json:"name,omitempty" yaml:"name"` + // The message content Content interface{} `json:"content" yaml:"content"` @@ -76,6 +80,20 @@ type Message struct { // A result of a function call FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` + + ToolCalls []ToolCall `json:"tool_calls,omitempty" yaml:"tool_call,omitempty"` +} + +type ToolCall struct { + Index int `json:"index"` + ID string `json:"id"` + Type string `json:"type"` + FunctionCall FunctionCall `json:"function"` +} + +type FunctionCall struct { + Name string `json:"name,omitempty"` + Arguments string `json:"arguments"` } type OpenAIModel struct { @@ -117,6 +135,9 @@ type OpenAIRequest struct { Functions []grammar.Function `json:"functions" yaml:"functions"` FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object + Tools []grammar.Tool `json:"tools,omitempty" yaml:"tools"` + ToolsChoice interface{} `json:"tool_choice,omitempty" yaml:"tool_choice"` + Stream bool `json:"stream"` // Image (not supported by OpenAI) diff --git a/api/schema/whisper.go b/core/schema/whisper.go similarity index 100% rename from api/schema/whisper.go rename to core/schema/whisper.go diff --git a/docs/data/version.json b/docs/data/version.json index f5a5a75c2e50..890f6c35949f 100644 --- a/docs/data/version.json +++ b/docs/data/version.json @@ -1,3 +1,3 @@ { - "version": "v2.8.0" + "version": "v2.8.2" } diff --git a/embedded/models/mistral-openorca.yaml b/embedded/models/mistral-openorca.yaml index 
fbab4e39ad1f..f40d854f72f1 100644 --- a/embedded/models/mistral-openorca.yaml +++ b/embedded/models/mistral-openorca.yaml @@ -11,20 +11,18 @@ template: <|im_start|>{{if eq .RoleName "assistant"}}assistant{{else if eq .RoleName "system"}}system{{else if eq .RoleName "user"}}user{{end}} {{if .Content}}{{.Content}}{{end}} <|im_end|> - chat: | {{.Input}} <|im_start|>assistant - completion: | {{.Input}} context_size: 4096 f16: true stopwords: - <|im_end|> - +- usage: | curl http://localhost:8080/v1/chat/completions -H "Content-Type: application/json" -d '{ "model": "mistral-openorca", "messages": [{"role": "user", "content": "How are you doing?", "temperature": 0.1}] - }' \ No newline at end of file + }' diff --git a/examples/configurations/phi-2.yaml b/examples/configurations/phi-2.yaml index 8f1938669602..cac1e9da9f5d 100644 --- a/examples/configurations/phi-2.yaml +++ b/examples/configurations/phi-2.yaml @@ -12,7 +12,7 @@ parameters: top_p: 0.95 seed: -1 template: - chat: &template | + chat: &template |- Instruct: {{.Input}} Output: completion: *template diff --git a/go.mod b/go.mod index 250a2361796f..bbd787b50f5a 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,6 @@ require ( github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e github.com/go-audio/wav v1.1.0 github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1 - github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e github.com/go-skynet/go-llama.cpp v0.0.0-20231009155254-aeba71ee8428 github.com/gofiber/fiber/v2 v2.50.0 github.com/google/uuid v1.3.1 @@ -28,6 +27,7 @@ require ( github.com/rs/zerolog v1.31.0 github.com/sashabaranov/go-openai v1.16.0 github.com/schollz/progressbar/v3 v3.13.1 + github.com/stretchr/testify v1.8.4 github.com/tmc/langchaingo v0.0.0-20231019140956-c636b3da7701 github.com/urfave/cli/v2 v2.25.7 github.com/valyala/fasthttp v1.50.0 @@ -55,6 +55,7 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.8.1 // indirect github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -68,6 +69,7 @@ require ( github.com/nwaples/rardecode v1.1.0 // indirect github.com/pierrec/lz4/v4 v4.1.2 // indirect github.com/pkoukk/tiktoken-go v0.1.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/procfs v0.11.1 // indirect diff --git a/go.sum b/go.sum index fc00bf6e2ae6..20dfbfb497d1 100644 --- a/go.sum +++ b/go.sum @@ -43,8 +43,6 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1 h1:yXvc7QfGtoZ51tUW/YVjoTwAfh8HG88XU7UOrbNlz5Y= github.com/go-skynet/go-bert.cpp v0.0.0-20230716133540-6abe312cded1/go.mod 
h1:fYjkCDRzC+oRLHSjQoajmYK6AmeJnmEanV27CClAcDc= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e h1:4reMY29i1eOZaRaSTMPNyXI7X8RMNxCTfDDBXYzrbr0= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230714203132-ffb09d7dd71e/go.mod h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A= github.com/go-skynet/go-llama.cpp v0.0.0-20231009155254-aeba71ee8428 h1:WYjkXL0Nw7dN2uDBMVCWQ8xLavrIhjF/DLczuh5L9TY= github.com/go-skynet/go-llama.cpp v0.0.0-20231009155254-aeba71ee8428/go.mod h1:iub0ugfTnflE3rcIuqV2pQSo15nEw3GLW/utm5gyERo= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= diff --git a/main.go b/main.go index edf703287851..7e4262ee57bd 100644 --- a/main.go +++ b/main.go @@ -12,10 +12,10 @@ import ( "syscall" "time" - api "github.com/go-skynet/LocalAI/api" - "github.com/go-skynet/LocalAI/api/backend" - config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/core/backend" + config "github.com/go-skynet/LocalAI/core/config" + api "github.com/go-skynet/LocalAI/core/http" + "github.com/go-skynet/LocalAI/core/options" "github.com/go-skynet/LocalAI/internal" "github.com/go-skynet/LocalAI/metrics" "github.com/go-skynet/LocalAI/pkg/gallery" @@ -142,6 +142,12 @@ func main() { EnvVars: []string{"AUDIO_PATH"}, Value: "/tmp/generated/audio", }, + &cli.StringFlag{ + Name: "upload-path", + Usage: "Path to store uploads from files api", + EnvVars: []string{"UPLOAD_PATH"}, + Value: "/tmp/localai/upload", + }, &cli.StringFlag{ Name: "backend-assets-path", Usage: "Path used to extract libraries that are required by some of the backends in runtime.", @@ -227,6 +233,7 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit options.WithDebug(ctx.Bool("debug")), options.WithImageDir(ctx.String("image-path")), options.WithAudioDir(ctx.String("audio-path")), + options.WithUploadDir(ctx.String("upload-path")), options.WithF16(ctx.Bool("f16")), options.WithStringGalleries(ctx.String("galleries")), options.WithModelLibraryURL(ctx.String("remote-library")), diff --git a/pkg/grammar/functions.go b/pkg/grammar/functions.go index ef56662b7b93..1038f5e6f147 100644 --- a/pkg/grammar/functions.go +++ b/pkg/grammar/functions.go @@ -11,6 +11,12 @@ type Function struct { } type Functions []Function +type Tool struct { + Type string `json:"type"` + Function Function `json:"function,omitempty"` +} +type Tools []Tool + func (f Functions) ToJSONStructure() JSONFunctionStructure { js := JSONFunctionStructure{} for _, function := range f { diff --git a/pkg/grammar/json_schema.go b/pkg/grammar/json_schema.go index 40d7f4e60cc6..76f9778f5b7f 100644 --- a/pkg/grammar/json_schema.go +++ b/pkg/grammar/json_schema.go @@ -105,11 +105,28 @@ func (sc *JSONSchemaConverter) addRule(name, rule string) string { return key } -func (sc *JSONSchemaConverter) formatGrammar() string { +const array = `arr ::= + "[\n" ( + realvalue + (",\n" realvalue)* + )? "]"` + +func (sc *JSONSchemaConverter) finalizeGrammar(maybeArray bool) string { var lines []string + // write down the computed rules. 
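The finalizeGrammar logic that continues just below adds an arr rule and rewrites the root rule as "arr | realvalue" when maybeArray is set, so a grammar-constrained model may emit either a single function-call object or a JSON array of them (the basis for parallel function calling). A usage sketch against the updated pkg/grammar API as shown in this diff; the schema literal here is illustrative only, not taken from the PR:

```go
package main

import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/grammar"
)

func main() {
	// Illustrative one-function schema; real callers build this from the
	// request's functions/tools definitions.
	schema := []byte(`{
	  "oneOf": [{
	    "type": "object",
	    "properties": {
	      "function":  {"const": "search"},
	      "arguments": {"type": "object", "properties": {"query": {"type": "string"}}}
	    }
	  }]
	}`)

	sc := grammar.NewJSONSchemaConverter("")
	// Passing true makes the root rule "arr | realvalue", so the model may
	// return one call or a JSON array of calls.
	fmt.Println(sc.GrammarFromBytes(schema, true))
}
```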
+ // if maybeArray is true, we need to add the array rule and slightly tweak the root rule for name, rule := range sc.rules { + if maybeArray && name == "root" { + name = "realvalue" + } lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule)) } + + if maybeArray { + lines = append(lines, fmt.Sprintf("%s ::= %s", "root", "arr | realvalue")) + lines = append(lines, array) + } + return strings.Join(lines, "\n") } @@ -234,15 +251,15 @@ func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[strin return def } -func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}) string { +func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, maybeArray bool) string { sc.visit(schema, "", schema) - return sc.formatGrammar() + return sc.finalizeGrammar(maybeArray) } -func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte) string { +func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, maybeArray bool) string { var schema map[string]interface{} _ = json.Unmarshal(b, &schema) - return sc.Grammar(schema) + return sc.Grammar(schema, maybeArray) } func jsonString(v interface{}) string { @@ -275,7 +292,7 @@ type JSONFunctionStructure struct { Defs map[string]interface{} `json:"$defs,omitempty"` } -func (j JSONFunctionStructure) Grammar(propOrder string) string { +func (j JSONFunctionStructure) Grammar(propOrder string, maybeArray bool) string { dat, _ := json.Marshal(j) - return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat) + return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray) } diff --git a/pkg/grammar/json_schema_test.go b/pkg/grammar/json_schema_test.go index 9d4086cbf12e..39d2a4d57886 100644 --- a/pkg/grammar/json_schema_test.go +++ b/pkg/grammar/json_schema_test.go @@ -52,13 +52,32 @@ string ::= "\"" ( [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) )* "\"" space +root-1-function ::= "\"search\""` + + inputResult2 = `root-0-function ::= "\"create_event\"" +root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"function\"" space ":" space root-0-function "}" space +root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space +realvalue ::= root-0 | root-1 +root ::= arr | realvalue +space ::= " "? +root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space +root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"function\"" space ":" space root-1-function "}" space +string ::= "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) +)* "\"" space +arr ::= + "[\n" ( + realvalue + (",\n" realvalue)* + )? 
"]" root-1-function ::= "\"search\""` ) var _ = Describe("JSON schema grammar tests", func() { Context("JSON", func() { It("generates a valid grammar from JSON schema", func() { - grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1)) + grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1), false) results := strings.Split(inputResult1, "\n") for _, r := range results { if r != "" { @@ -103,7 +122,7 @@ var _ = Describe("JSON schema grammar tests", func() { }, }} - grammar := structuredGrammar.Grammar("") + grammar := structuredGrammar.Grammar("", false) results := strings.Split(inputResult1, "\n") for _, r := range results { if r != "" { @@ -112,5 +131,50 @@ var _ = Describe("JSON schema grammar tests", func() { } Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n")))) }) + + It("generates a valid grammar from JSON Objects for multiple function return", func() { + structuredGrammar := JSONFunctionStructure{ + OneOf: []Item{ + { + Type: "object", + Properties: Properties{ + Function: FunctionName{ + Const: "create_event", + }, + Arguments: Argument{ // this is OpenAI's parameter + Type: "object", + Properties: map[string]interface{}{ + "title": map[string]string{"type": "string"}, + "date": map[string]string{"type": "string"}, + "time": map[string]string{"type": "string"}, + }, + }, + }, + }, + { + Type: "object", + Properties: Properties{ + Function: FunctionName{ + Const: "search", + }, + Arguments: Argument{ + Type: "object", + Properties: map[string]interface{}{ + "query": map[string]string{"type": "string"}, + }, + }, + }, + }, + }} + + grammar := structuredGrammar.Grammar("", true) + results := strings.Split(inputResult2, "\n") + for _, r := range results { + if r != "" { + Expect(grammar).To(ContainSubstring(r)) + } + } + Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))), grammar) + }) }) }) diff --git a/pkg/grpc/backend.go b/pkg/grpc/backend.go index ae8ffc5fe714..22933d584a07 100644 --- a/pkg/grpc/backend.go +++ b/pkg/grpc/backend.go @@ -2,7 +2,8 @@ package grpc import ( "context" - "github.com/go-skynet/LocalAI/api/schema" + + "github.com/go-skynet/LocalAI/core/schema" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" "google.golang.org/grpc" ) diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go index 739d1cbbe6bb..89c8785e6b8c 100644 --- a/pkg/grpc/base/base.go +++ b/pkg/grpc/base/base.go @@ -6,7 +6,7 @@ import ( "fmt" "os" - "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/core/schema" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" gopsutil "github.com/shirou/gopsutil/v3/process" ) diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 5e97ea73e068..9058db05e018 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -7,7 +7,7 @@ import ( "sync" "time" - "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/core/schema" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" diff --git a/pkg/grpc/embed.go b/pkg/grpc/embed.go index b9ab551f63c3..228b1df56bef 100644 --- a/pkg/grpc/embed.go +++ b/pkg/grpc/embed.go @@ -2,11 +2,12 @@ package grpc import ( "context" - "github.com/go-skynet/LocalAI/api/schema" + "time" + + "github.com/go-skynet/LocalAI/core/schema" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" 
"google.golang.org/grpc" "google.golang.org/grpc/metadata" - "time" ) var _ Backend = new(embedBackend) diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index a76261c15ce9..1cc7cb3d8762 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -1,7 +1,7 @@ package grpc import ( - "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/core/schema" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" ) diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 37c2a603a634..bea32fb72a4c 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -33,6 +33,7 @@ type ChatMessageTemplateData struct { SystemPrompt string Role string RoleName string + FunctionName string Content string MessageIndex int } diff --git a/pkg/utils/path.go b/pkg/utils/path.go index 05481d2cc2a5..f95b0138133a 100644 --- a/pkg/utils/path.go +++ b/pkg/utils/path.go @@ -3,6 +3,7 @@ package utils import ( "fmt" "path/filepath" + "strings" ) func inTrustedRoot(path string, trustedRoot string) error { @@ -20,3 +21,14 @@ func VerifyPath(path, basePath string) error { c := filepath.Clean(filepath.Join(basePath, path)) return inTrustedRoot(c, filepath.Clean(basePath)) } + +// SanitizeFileName sanitizes the given filename +func SanitizeFileName(fileName string) string { + // filepath.Clean to clean the path + cleanName := filepath.Clean(fileName) + // filepath.Base to ensure we only get the final element, not any directory path + baseName := filepath.Base(cleanName) + // Replace any remaining tricky characters that might have survived cleaning + safeName := strings.ReplaceAll(baseName, "..", "") + return safeName +} diff --git a/tests/integration/reflect_test.go b/tests/integration/reflect_test.go index c0fe7096a1d8..bf3f8a5bc1ac 100644 --- a/tests/integration/reflect_test.go +++ b/tests/integration/reflect_test.go @@ -3,7 +3,7 @@ package integration_test import ( "reflect" - config "github.com/go-skynet/LocalAI/api/config" + config "github.com/go-skynet/LocalAI/core/config" model "github.com/go-skynet/LocalAI/pkg/model" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega"