Skip to content

Commit dddd67d

Browse files
committed
Adhere to new return types when using tools instead of functions
1 parent ccf5faf commit dddd67d

File tree

2 files changed

+35
-79
lines changed

2 files changed

+35
-79
lines changed

api/openai/chat.go

Lines changed: 29 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -55,78 +55,6 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
5555
})
5656
close(responses)
5757
}
58-
59-
/*
60-
data:
61-
{
62-
"id":"chatcmpl-8sZrzBdLsWvnO2lX7Vz6glYAz8JMk",
63-
"object":"chat.completion.chunk",
64-
"created":1708018287,
65-
"model":"gpt-3.5-turbo-0613",
66-
"system_fingerprint":null,
67-
"choices":[
68-
{
69-
"index":0,
70-
"delta": {
71-
"role":"assistant",
72-
"content":null,
73-
"tool_calls":
74-
[
75-
{
76-
"index":0,
77-
"id":"call_kL07suiDkGzYbUCLMZZ5XUIU",
78-
"type":"function",
79-
"function":
80-
{
81-
"name":"get_current_weather",
82-
"arguments":""
83-
}
84-
}
85-
]
86-
},
87-
"logprobs":null,
88-
"finish_reason":null
89-
}]
90-
}
91-
92-
data: {"id":"chatcmpl-8sZrzBdLsWvnO2lX7Vz6glYAz8JMk","object":"chat.completion.chunk","created":1708018287,"model":"gpt-3.5-turbo-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"a
93-
rguments":"{\n"}}]},"logprobs":null,"finish_reason":null}]}
94-
95-
data: {"id":"chatcmpl-8sZrzBdLsWvnO2lX7Vz6glYAz8JMk","object":"chat.completion.chunk","created":1708018287,"model":"gpt-3.5-turbo-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"a
96-
rguments":" "}}]},"logprobs":null,"finish_reason":null}]}
97-
98-
data: {"id":"chatcmpl-8sZrzBdLsWvnO2lX7Vz6glYAz8JMk","object":"chat.completion.chunk","created":1708018287,"model":"gpt-3.5-turbo-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"a
99-
rguments":" \""}}]},"logprobs":null,"finish_reason":null}]}
100-
101-
data: {"id":"chatcmpl-8sZrzBdLsWvnO2lX7Vz6glYAz8JMk","object":"chat.completion.chunk","created":1708018287,"model":"gpt-3.5-turbo-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"a
102-
rguments":"location"}}]},"logprobs":null,"finish_reason":null}]}
103-
104-
data: {"id":"chatcmpl-8sZrzBdLsWvnO2lX7Vz6glYAz8JMk","object":"chat.completion.chunk","created":1708018287,"model":"gpt-3.5-turbo-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"a
105-
rguments":"\":"}}]},"logprobs":null,"finish_reason":null}]}
106-
107-
data: {"id":"chatcmpl-8sZrzBdLsWvnO2lX7Vz6glYAz8JMk","object":"chat.completion.chunk","created":1708018287,"model":"gpt-3.5-turbo-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"a
108-
rguments":" \""}}]},"logprobs":null,"finish_reason":null}]}
109-
110-
data: {"id":"chatcmpl-8sZrzBdLsWvnO2lX7Vz6glYAz8JMk","object":"chat.completion.chunk","created":1708018287,"model":"gpt-3.5-turbo-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"a
111-
rguments":"Boston"}}]},"logprobs":null,"finish_reason":null}]}
112-
113-
data: {"id":"chatcmpl-8sZrzBdLsWvnO2lX7Vz6glYAz8JMk","object":"chat.completion.chunk","created":1708018287,"model":"gpt-3.5-turbo-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"a
114-
rguments":","}}]},"logprobs":null,"finish_reason":null}]}
115-
116-
data: {"id":"chatcmpl-8sZrzBdLsWvnO2lX7Vz6glYAz8JMk","object":"chat.completion.chunk","created":1708018287,"model":"gpt-3.5-turbo-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"a
117-
rguments":" MA"}}]},"logprobs":null,"finish_reason":null}]}
118-
119-
data: {"id":"chatcmpl-8sZrzBdLsWvnO2lX7Vz6glYAz8JMk","object":"chat.completion.chunk","created":1708018287,"model":"gpt-3.5-turbo-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"a
120-
rguments":"\"\n"}}]},"logprobs":null,"finish_reason":null}]}
121-
122-
data: {"id":"chatcmpl-8sZrzBdLsWvnO2lX7Vz6glYAz8JMk","object":"chat.completion.chunk","created":1708018287,"model":"gpt-3.5-turbo-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"a
123-
rguments":"}"}}]},"logprobs":null,"finish_reason":null}]}
124-
125-
data: {"id":"chatcmpl-8sZrzBdLsWvnO2lX7Vz6glYAz8JMk","object":"chat.completion.chunk","created":1708018287,"model":"gpt-3.5-turbo-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool
126-
_calls"}]}
127-
128-
data: [DONE]
129-
*/
13058
processTools := func(prompt string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) {
13159
ComputeChoices(req, prompt, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
13260
ss := map[string]interface{}{}
@@ -391,7 +319,6 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
391319
}
392320

393321
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
394-
395322
usage := &schema.OpenAIUsage{}
396323

397324
for ev := range responses {
@@ -488,11 +415,35 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
488415
fineTunedResponse := backend.Finetune(*config, predInput, prediction.Response)
489416
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}})
490417
} else {
491-
// otherwise reply with the function call
492-
*c = append(*c, schema.Choice{
493-
FinishReason: "function_call",
494-
Message: &schema.Message{Role: "assistant", FunctionCall: ss},
495-
})
418+
if len(input.Tools) > 0 {
419+
// Result is different in the case we have a tool call
420+
*c = append(*c, schema.Choice{
421+
FinishReason: "tool_calls",
422+
Message: &schema.Message{
423+
Role: "assistant",
424+
ToolCalls: []schema.ToolCall{
425+
{
426+
ID: id,
427+
Type: "function",
428+
FunctionCall: schema.FunctionCall{
429+
Name: name,
430+
Arguments: args,
431+
},
432+
},
433+
},
434+
FunctionCall: ss,
435+
},
436+
})
437+
} else {
438+
// otherwise reply with the function call
439+
*c = append(*c, schema.Choice{
440+
FinishReason: "function_call",
441+
Message: &schema.Message{
442+
Role: "assistant",
443+
FunctionCall: ss,
444+
},
445+
})
446+
}
496447
}
497448

498449
return

api/openai/request.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
fiberContext "github.com/go-skynet/LocalAI/api/ctx"
1414
options "github.com/go-skynet/LocalAI/api/options"
1515
"github.com/go-skynet/LocalAI/api/schema"
16+
"github.com/go-skynet/LocalAI/pkg/grammar"
1617
model "github.com/go-skynet/LocalAI/pkg/model"
1718
"github.com/gofiber/fiber/v2"
1819
"github.com/rs/zerolog/log"
@@ -143,7 +144,11 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) {
143144
}
144145

145146
if input.ToolsChoice != nil {
146-
input.FunctionCall = input.ToolsChoice
147+
var toolChoice grammar.Tool
148+
json.Unmarshal([]byte(input.ToolsChoice.(string)), &toolChoice)
149+
input.FunctionCall = map[string]interface{}{
150+
"name": toolChoice.Function.Name,
151+
}
147152
}
148153

149154
// Decode each request's message content

0 commit comments

Comments (0)