Merged
Changes from 7 commits
api/api.go (6 changes: 5 additions & 1 deletion)
@@ -146,7 +146,11 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
     }
 
     // Default middleware config
-    app.Use(recover.New())
+
+    if !options.Debug {
+        app.Use(recover.New())
+    }
+
     if options.Metrics != nil {
         app.Use(metrics.APIMiddleware(options.Metrics))
     }
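Why gate the recover middleware on debug mode: recover.New() converts a handler panic into an HTTP 500 response, which keeps a production server alive but hides the stack trace. A minimal, self-contained sketch of the same pattern (the route and the local debug flag are invented for illustration; only the fiber v2 APIs are real):

    package main

    import (
        "github.com/gofiber/fiber/v2"
        "github.com/gofiber/fiber/v2/middleware/recover"
    )

    func main() {
        debug := false // stand-in for options.Debug

        app := fiber.New()
        if !debug {
            // In production, a panicking handler becomes an HTTP 500 and
            // the server keeps running.
            app.Use(recover.New())
        }
        // In debug mode the middleware is skipped, so the panic propagates
        // and prints a full stack trace, which is easier to troubleshoot.

        app.Get("/boom", func(c *fiber.Ctx) error {
            panic("handler bug") // hypothetical failing handler
        })

        app.Listen(":8080")
    }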
api/openai/chat.go (280 changes: 191 additions & 89 deletions)
@@ -55,6 +55,65 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
         })
         close(responses)
     }
+    processTools := func(prompt string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
+
+        result := ""
+        ComputeChoices(req, prompt, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
+            result += s
+            // TODO: Change generated BNF grammar to be compliant with the schema so we can
+            // stream the result token by token here.
+            return true
+        })
+
+        ss := map[string]interface{}{}
+        name, args := parseFunctionCall(result)
+        ss["name"], ss["arguments"] = name, args
+
+        initialMessage := schema.OpenAIResponse{
+            ID:      id,
+            Created: created,
+            Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
+            Choices: []schema.Choice{{
+                Delta: &schema.Message{
+                    Role: "assistant",
+                    ToolCalls: []schema.ToolCall{
+                        {
+                            Index: 0,
+                            ID:    id,
+                            Type:  "function",
+                            FunctionCall: schema.FunctionCall{
+                                Name: name,
+                            },
+                        },
+                    },
+                }}},
+            Object: "chat.completion.chunk",
+        }
+        responses <- initialMessage
+
+        responses <- schema.OpenAIResponse{
+            ID:      id,
+            Created: created,
+            Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
+            Choices: []schema.Choice{{
+                Delta: &schema.Message{
+                    Role: "assistant",
+                    ToolCalls: []schema.ToolCall{
+                        {
+                            Index: 0,
+                            ID:    id,
+                            Type:  "function",
+                            FunctionCall: schema.FunctionCall{
+                                Arguments: args,
+                            },
+                        },
+                    },
+                }}},
+            Object: "chat.completion.chunk",
+        }
+        close(responses)
+    }
+
     return func(c *fiber.Ctx) error {
         processFunctions := false
         funcs := grammar.Functions{}
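processTools streams a tool call in two chat.completion.chunk deltas: the first carries only the function name, the second only the stringified arguments, and the channel is then closed. A hypothetical client-side sketch of how those two chunks might be reassembled (the chunk payloads and IDs are invented; the field names follow the OpenAI streaming wire format):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    type toolCallDelta struct {
        Choices []struct {
            Delta struct {
                ToolCalls []struct {
                    ID       string `json:"id"`
                    Type     string `json:"type"`
                    Function struct {
                        Name      string `json:"name"`
                        Arguments string `json:"arguments"`
                    } `json:"function"`
                } `json:"tool_calls"`
            } `json:"delta"`
        } `json:"choices"`
    }

    func main() {
        // Invented example chunks, in the order processTools emits them.
        chunks := []string{
            `{"choices":[{"delta":{"tool_calls":[{"id":"chat-1","type":"function","function":{"name":"get_weather"}}]}}]}`,
            `{"choices":[{"delta":{"tool_calls":[{"id":"chat-1","type":"function","function":{"arguments":"{\"city\":\"Rome\"}"}}]}}]}`,
        }

        name, args := "", ""
        for _, raw := range chunks {
            var c toolCallDelta
            if err := json.Unmarshal([]byte(raw), &c); err != nil {
                panic(err)
            }
            tc := c.Choices[0].Delta.ToolCalls[0]
            if tc.Function.Name != "" {
                name = tc.Function.Name // first chunk: the call name
            }
            args += tc.Function.Arguments // second chunk: the arguments
        }
        fmt.Printf("assembled call: %s(%s)\n", name, args)
    }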
@@ -122,7 +181,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
         }
 
         // functions are not supported in stream mode (yet?)
-        toStream := input.Stream && !processFunctions
+        toStream := input.Stream
 
         log.Debug().Msgf("Parameters: %+v", config)
 
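With the !processFunctions guard dropped (note that the comment above the changed line is now stale), a request that sets both stream and tools takes the streaming branch instead of silently falling back to a blocking response. An illustrative request, assuming a local server on the default :8080 and an invented model name:

    package main

    import (
        "bytes"
        "io"
        "net/http"
        "os"
    )

    func main() {
        body := []byte(`{
          "model": "my-model",
          "stream": true,
          "messages": [{"role": "user", "content": "What is the weather in Rome?"}],
          "tools": [{
            "type": "function",
            "function": {
              "name": "get_weather",
              "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
            }
          }]
        }`)

        resp, err := http.Post("http://localhost:8080/v1/chat/completions", "application/json", bytes.NewReader(body))
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()
        io.Copy(os.Stdout, resp.Body) // the tool-call deltas arrive as SSE events
    }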
@@ -145,13 +204,15 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
             }
             r := config.Roles[role]
             contentExists := i.Content != nil && i.StringContent != ""
+
             // First attempt to populate content via a chat message specific template
             if config.TemplateConfig.ChatMessage != "" {
                 chatMessageData := model.ChatMessageTemplateData{
                     SystemPrompt: config.SystemPrompt,
                     Role:         r,
                     RoleName:     role,
                     Content:      i.StringContent,
+                    FunctionName: i.Name,
                     MessageIndex: messageIndex,
                 }
                 templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData)
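The new FunctionName field passes the message's Name through to per-message templates, so a template can render tool results differently from ordinary chat turns. A reduced, runnable sketch with a stand-in struct mirroring model.ChatMessageTemplateData and an invented template:

    package main

    import (
        "os"
        "text/template"
    )

    // Stand-in for model.ChatMessageTemplateData (fields simplified).
    type chatMessageTemplateData struct {
        Role         string
        RoleName     string
        Content      string
        FunctionName string
        MessageIndex int
    }

    func main() {
        // Hypothetical per-message template: tool results get wrapped in a
        // tag carrying the function name, plain turns render as role: text.
        tmpl := template.Must(template.New("chatMessage").Parse(
            `{{if .FunctionName}}<tool name="{{.FunctionName}}">{{.Content}}</tool>
{{else}}{{.Role}}: {{.Content}}
{{end}}`))

        msgs := []chatMessageTemplateData{
            {Role: "user", Content: "What is the weather in Rome?", MessageIndex: 0},
            {Role: "tool", FunctionName: "get_weather", Content: `{"temp": 21}`, MessageIndex: 1},
        }
        for _, m := range msgs {
            tmpl.Execute(os.Stdout, m)
        }
    }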
@@ -254,13 +315,17 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
             log.Debug().Msgf("Grammar: %+v", config.Grammar)
         }
 
-        if toStream {
+        switch {
+        case toStream:
             responses := make(chan schema.OpenAIResponse)
 
-            go process(predInput, input, config, o.Loader, responses)
+            if !processFunctions {
+                go process(predInput, input, config, o.Loader, responses)
+            } else {
+                go processTools(predInput, input, config, o.Loader, responses)
+            }
 
             c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
-
                 usage := &schema.OpenAIUsage{}
 
                 for ev := range responses {
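Both streaming paths share one pattern: a producer goroutine (process or processTools) writes OpenAIResponse values to a channel and closes it when done, while the fasthttp stream writer drains the channel and flushes each event. A stripped-down sketch of that pattern, with strings standing in for schema.OpenAIResponse:

    package main

    import (
        "bufio"
        "fmt"
        "os"
    )

    func produce(responses chan<- string) {
        for i := 0; i < 3; i++ {
            responses <- fmt.Sprintf(`{"choices":[{"delta":{"content":"token-%d"}}]}`, i)
        }
        close(responses) // tells the consumer the stream is complete
    }

    func main() {
        responses := make(chan string)
        go produce(responses)

        w := bufio.NewWriter(os.Stdout)
        for ev := range responses { // ranges until produce() closes the channel
            fmt.Fprintf(w, "data: %s\n\n", ev)
            w.Flush() // push each SSE event to the client immediately
        }
    }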
@@ -278,13 +343,18 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
                     w.Flush()
                 }
 
+                finishReason := "stop"
+                if processFunctions && len(input.Tools) > 0 {
+                    finishReason = "tool_calls"
+                }
+
                 resp := &schema.OpenAIResponse{
                     ID:      id,
                     Created: created,
                     Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
                     Choices: []schema.Choice{
                         {
-                            FinishReason: "stop",
+                            FinishReason: finishReason,
                             Index:        0,
                             Delta:        &schema.Message{Content: &emptyMessage},
                         }},
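The final chunk of a stream now reports finish_reason "tool_calls" instead of "stop" when functions were processed for a tools request, which is what tells OpenAI-compatible clients to execute the accumulated tool calls rather than print a message. A small sketch of how that last chunk might look on the wire (simplified stand-in types, invented payload):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    type choice struct {
        FinishReason string         `json:"finish_reason"`
        Index        int            `json:"index"`
        Delta        map[string]any `json:"delta"`
    }

    type chunk struct {
        Object  string   `json:"object"`
        Choices []choice `json:"choices"`
    }

    func main() {
        processFunctions, hasTools := true, true

        finishReason := "stop"
        if processFunctions && hasTools {
            finishReason = "tool_calls"
        }

        final := chunk{
            Object:  "chat.completion.chunk",
            Choices: []choice{{FinishReason: finishReason, Index: 0, Delta: map[string]any{"content": ""}}},
        }
        out, _ := json.Marshal(final)
        fmt.Printf("data: %s\n\ndata: [DONE]\n", out)
    }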
@@ -298,102 +368,134 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
                 w.Flush()
             }))
             return nil
-        }
-
-        result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) {
-            if processFunctions {
-                // As we have to change the result before processing, we can't stream the answer (yet?)
-                ss := map[string]interface{}{}
-                // This prevent newlines to break JSON parsing for clients
-                s = utils.EscapeNewLines(s)
-                json.Unmarshal([]byte(s), &ss)
-                log.Debug().Msgf("Function return: %s %+v", s, ss)
-
-                // The grammar defines the function name as "function", while OpenAI returns "name"
-                func_name := ss["function"]
-                // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
-                args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
-                d, _ := json.Marshal(args)
-
-                ss["arguments"] = string(d)
-                ss["name"] = func_name
-
-                // if do nothing, reply with a message
-                if func_name == noActionName {
-                    log.Debug().Msgf("nothing to do, computing a reply")
-
-                    // If there is a message that the LLM already sends as part of the JSON reply, use it
-                    arguments := map[string]interface{}{}
-                    json.Unmarshal([]byte(d), &arguments)
-                    m, exists := arguments["message"]
-                    if exists {
-                        switch message := m.(type) {
-                        case string:
-                            if message != "" {
-                                log.Debug().Msgf("Reply received from LLM: %s", message)
-                                message = backend.Finetune(*config, predInput, message)
-                                log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
-
-                                *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}})
-                                return
-                            }
-                        }
-                    }
-
-                    log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
-                    // Otherwise ask the LLM to understand the JSON output and the context, and return a message
-                    // Note: This costs (in term of CPU) another computation
-                    config.Grammar = ""
-                    images := []string{}
-                    for _, m := range input.Messages {
-                        images = append(images, m.StringImages...)
-                    }
-                    predFunc, err := backend.ModelInference(input.Context, predInput, images, o.Loader, *config, o, nil)
-                    if err != nil {
-                        log.Error().Msgf("inference error: %s", err.Error())
-                        return
-                    }
-
-                    prediction, err := predFunc()
-                    if err != nil {
-                        log.Error().Msgf("inference error: %s", err.Error())
-                        return
-                    }
-
-                    fineTunedResponse := backend.Finetune(*config, predInput, prediction.Response)
-                    *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}})
-                } else {
-                    // otherwise reply with the function call
-                    *c = append(*c, schema.Choice{
-                        FinishReason: "function_call",
-                        Message:      &schema.Message{Role: "assistant", FunctionCall: ss},
-                    })
-                }
-
-                return
-            }
-            *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
-        }, nil)
-        if err != nil {
-            return err
-        }
-
-        resp := &schema.OpenAIResponse{
-            ID:      id,
-            Created: created,
-            Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
-            Choices: result,
-            Object:  "chat.completion",
-            Usage: schema.OpenAIUsage{
-                PromptTokens:     tokenUsage.Prompt,
-                CompletionTokens: tokenUsage.Completion,
-                TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
-            },
-        }
-        respData, _ := json.Marshal(resp)
-        log.Debug().Msgf("Response: %s", respData)
-
-        // Return the prediction in the response body
-        return c.JSON(resp)
+        default:
+            result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) {
+                if processFunctions {
+                    ss := map[string]interface{}{}
+
+                    name, args := parseFunctionCall(s)
+                    ss["name"], ss["arguments"] = name, args
+
+                    // if do nothing, reply with a message
+                    if name == noActionName {
+                        log.Debug().Msgf("nothing to do, computing a reply")
+
+                        // If there is a message that the LLM already sends as part of the JSON reply, use it
+                        arguments := map[string]interface{}{}
+                        json.Unmarshal([]byte(args), &arguments)
+                        m, exists := arguments["message"]
+                        if exists {
+                            switch message := m.(type) {
+                            case string:
+                                if message != "" {
+                                    log.Debug().Msgf("Reply received from LLM: %s", message)
+                                    message = backend.Finetune(*config, predInput, message)
+                                    log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
+
+                                    *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}})
+                                    return
+                                }
+                            }
+                        }
+
+                        log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
+                        // Otherwise ask the LLM to understand the JSON output and the context, and return a message
+                        // Note: This costs (in term of CPU) another computation
+                        config.Grammar = ""
+                        images := []string{}
+                        for _, m := range input.Messages {
+                            images = append(images, m.StringImages...)
+                        }
+                        predFunc, err := backend.ModelInference(input.Context, predInput, images, o.Loader, *config, o, nil)
+                        if err != nil {
+                            log.Error().Msgf("inference error: %s", err.Error())
+                            return
+                        }
+
+                        prediction, err := predFunc()
+                        if err != nil {
+                            log.Error().Msgf("inference error: %s", err.Error())
+                            return
+                        }
+
+                        fineTunedResponse := backend.Finetune(*config, predInput, prediction.Response)
+                        *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}})
+                    } else {
+                        if len(input.Tools) > 0 {
+                            // Result is different in the case we have a tool call
+                            *c = append(*c, schema.Choice{
+                                FinishReason: "tool_calls",
+                                Message: &schema.Message{
+                                    Role: "assistant",
+                                    ToolCalls: []schema.ToolCall{
+                                        {
+                                            ID:   id,
+                                            Type: "function",
+                                            FunctionCall: schema.FunctionCall{
+                                                Name:      name,
+                                                Arguments: args,
+                                            },
+                                        },
+                                    },
+                                    FunctionCall: ss,
+                                },
+                            })
+                        } else {
+                            // otherwise reply with the function call
+                            *c = append(*c, schema.Choice{
+                                FinishReason: "function_call",
+                                Message: &schema.Message{
+                                    Role:         "assistant",
+                                    FunctionCall: ss,
+                                },
+                            })
+                        }
+                    }
+
+                    return
+                }
+                *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
+            }, nil)
+            if err != nil {
+                return err
+            }
+
+            resp := &schema.OpenAIResponse{
+                ID:      id,
+                Created: created,
+                Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
+                Choices: result,
+                Object:  "chat.completion",
+                Usage: schema.OpenAIUsage{
+                    PromptTokens:     tokenUsage.Prompt,
+                    CompletionTokens: tokenUsage.Completion,
+                    TotalTokens:      tokenUsage.Prompt + tokenUsage.Completion,
+                },
+            }
+            respData, _ := json.Marshal(resp)
+            log.Debug().Msgf("Response: %s", respData)
+
+            // Return the prediction in the response body
+            return c.JSON(resp)
+        }
     }
 }
+
+func parseFunctionCall(llmresult string) (string, string) {
+    // As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
+    ss := map[string]interface{}{}
+    // This prevent newlines to break JSON parsing for clients
+    s := utils.EscapeNewLines(llmresult)
+    json.Unmarshal([]byte(s), &ss)
+    log.Debug().Msgf("Function return: %s %+v", s, ss)
+
+    // The grammar defines the function name as "function", while OpenAI returns "name"
+    func_name := ss["function"]
+    // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
+    args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
+    d, _ := json.Marshal(args)
+
+    return func_name.(string), string(d)
+}
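parseFunctionCall bridges two shapes: the grammar-constrained model output keys the call under "function" and returns "arguments" as a JSON object, while the OpenAI format wants "name" plus arguments as a JSON-encoded string. A worked example of the same transformation, with an invented model result:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    func main() {
        // Invented grammar-constrained output from the model.
        llmresult := `{"function": "get_weather", "arguments": {"city": "Rome", "unit": "celsius"}}`

        ss := map[string]interface{}{}
        json.Unmarshal([]byte(llmresult), &ss)

        // "function" becomes the OpenAI "name"; the type assertion is safe
        // here because the input is known.
        funcName := ss["function"].(string)
        // The arguments object is re-stringified, as OpenAI clients expect.
        d, _ := json.Marshal(ss["arguments"])

        fmt.Println(funcName)  // get_weather
        fmt.Println(string(d)) // {"city":"Rome","unit":"celsius"}
    }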
api/openai/request.go (15 changes: 15 additions & 0 deletions)
@@ -13,6 +13,7 @@ import (
     fiberContext "github.com/go-skynet/LocalAI/api/ctx"
     options "github.com/go-skynet/LocalAI/api/options"
     "github.com/go-skynet/LocalAI/api/schema"
+    "github.com/go-skynet/LocalAI/pkg/grammar"
     model "github.com/go-skynet/LocalAI/pkg/model"
     "github.com/gofiber/fiber/v2"
     "github.com/rs/zerolog/log"
@@ -136,6 +137,20 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
         }
     }
 
+    if len(input.Tools) > 0 {
+        for _, tool := range input.Tools {
+            input.Functions = append(input.Functions, tool.Function)
+        }
+    }
+
+    if input.ToolsChoice != nil {
+        var toolChoice grammar.Tool
+        json.Unmarshal([]byte(input.ToolsChoice.(string)), &toolChoice)
+        input.FunctionCall = map[string]interface{}{
+            "name": toolChoice.Function.Name,
+        }
+    }
+
     // Decode each request's message content
     index := 0
     for i, m := range input.Messages {
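updateRequestConfig folds the newer tools/tool_choice request fields onto the pre-existing functions/function_call path, so the rest of the grammar machinery works unchanged. A reduced sketch of that mapping with simplified stand-in types (the real ones live in the schema and grammar packages):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    type function struct {
        Name string `json:"name"`
    }

    type tool struct {
        Type     string   `json:"type"`
        Function function `json:"function"`
    }

    func main() {
        tools := []tool{{Type: "function", Function: function{Name: "get_weather"}}}
        toolsChoice := `{"type": "function", "function": {"name": "get_weather"}}`

        // Fold each tools[].function entry onto the older functions list.
        var functions []function
        for _, t := range tools {
            functions = append(functions, t.Function)
        }

        // Translate a JSON-encoded tool_choice into the function_call selector.
        var choice tool
        json.Unmarshal([]byte(toolsChoice), &choice)
        functionCall := map[string]interface{}{"name": choice.Function.Name}

        fmt.Println(functions, functionCall) // [{get_weather}] map[name:get_weather]
    }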