Skip to content

Commit ede4114

Browse files
authored
Made function call streamable (#98)
* made function call streamable * Revert `FunctionCall` and Introduce `FunctionCallStream`
1 parent 263eb70 commit ede4114

File tree

2 files changed

+90
-60
lines changed

2 files changed

+90
-60
lines changed

async-openai/src/types/types.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -869,12 +869,17 @@ pub type ChatCompletionResponseStream =
869869
Pin<Box<dyn Stream<Item = Result<CreateChatCompletionStreamResponse, OpenAIError>> + Send>>;
870870

871871
// For reason (not documented by OpenAI) the response from stream is different
872+
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
873+
pub struct FunctionCallStream {
874+
pub name: Option<String>,
875+
pub arguments: Option<String>,
876+
}
872877

873878
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
874879
pub struct ChatCompletionStreamResponseDelta {
875880
pub role: Option<Role>,
876881
pub content: Option<String>,
877-
pub function_call: Option<FunctionCall>,
882+
pub function_call: Option<FunctionCallStream>,
878883
}
879884

880885
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]

examples/function-call-stream/src/main.rs

Lines changed: 84 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use async_openai::{
1212

1313
use futures::StreamExt;
1414
use serde_json::json;
15+
use async_openai::config::OpenAIConfig;
1516

1617
#[tokio::main]
1718
async fn main() -> Result<(), Box<dyn Error>> {
@@ -42,71 +43,95 @@ async fn main() -> Result<(), Box<dyn Error>> {
4243
.function_call("auto")
4344
.build()?;
4445

45-
// the first response from GPT is just the json response containing the function that was called
46-
// and the model-generated arguments for that function (don't stream this)
47-
let response = client
48-
.chat()
49-
.create(request)
50-
.await?
51-
.choices
52-
.get(0)
53-
.unwrap()
54-
.message
55-
.clone();
56-
57-
if let Some(function_call) = response.function_call {
58-
let mut available_functions: HashMap<&str, fn(&str, &str) -> serde_json::Value> =
59-
HashMap::new();
60-
available_functions.insert("get_current_weather", get_current_weather);
61-
62-
let function_name = function_call.name;
63-
let function_args: serde_json::Value = function_call.arguments.parse().unwrap();
64-
65-
let location = function_args["location"].as_str().unwrap();
66-
let unit = "fahrenheit"; // why doesn't the model return a unit argument?
67-
let function = available_functions.get(function_name.as_str()).unwrap();
68-
let function_response = function(location, unit); // call the function
69-
70-
let message = vec![
71-
ChatCompletionRequestMessageArgs::default()
72-
.role(Role::User)
73-
.content("What's the weather like in Boston?")
74-
.build()?,
75-
ChatCompletionRequestMessageArgs::default()
76-
.role(Role::Function)
77-
.content(function_response.to_string())
78-
.name(function_name)
79-
.build()?,
80-
];
81-
82-
let request = CreateChatCompletionRequestArgs::default()
83-
.max_tokens(512u16)
84-
.model("gpt-3.5-turbo-0613")
85-
.messages(message)
86-
.build()?;
87-
88-
// Now stream received response from model, which essentially formats the function response
89-
let mut stream = client.chat().create_stream(request).await?;
90-
91-
let mut lock = stdout().lock();
92-
while let Some(result) = stream.next().await {
93-
match result {
94-
Ok(response) => {
95-
response.choices.iter().for_each(|chat_choice| {
96-
if let Some(ref content) = chat_choice.delta.content {
97-
write!(lock, "{}", content).unwrap();
46+
let mut stream = client.chat().create_stream(request).await?;
47+
48+
let mut fn_name = String::new();
49+
let mut fn_args = String::new();
50+
51+
let mut lock = stdout().lock();
52+
while let Some(result) = stream.next().await {
53+
match result {
54+
Ok(response) => {
55+
for chat_choice in response.choices {
56+
if let Some(fn_call) = &chat_choice.delta.function_call {
57+
writeln!(lock, "function_call: {:?}", fn_call).unwrap();
58+
if let Some(name) = &fn_call.name {
59+
fn_name = name.clone();
9860
}
99-
});
100-
}
101-
Err(err) => {
102-
writeln!(lock, "error: {err}").unwrap();
61+
if let Some(args) = &fn_call.arguments {
62+
fn_args.push_str(args);
63+
}
64+
}
65+
if let Some(finish_reason) = &chat_choice.finish_reason {
66+
if finish_reason == "function_call" {
67+
call_fn(&client, &fn_name, &fn_args).await?;
68+
}
69+
} else if let Some(content) = &chat_choice.delta.content {
70+
write!(lock, "{}", content).unwrap();
71+
}
10372
}
10473
}
105-
stdout().flush()?;
74+
Err(err) => {
75+
writeln!(lock, "error: {err}").unwrap();
76+
}
10677
}
107-
println!("{}", "\n");
78+
stdout().flush()?;
10879
}
10980

81+
82+
Ok(())
83+
}
84+
85+
async fn call_fn(client: &Client<OpenAIConfig>, name: &str, args: &str) -> Result<(), Box<dyn Error>> {
86+
let mut available_functions: HashMap<&str, fn(&str, &str) -> serde_json::Value> =
87+
HashMap::new();
88+
available_functions.insert("get_current_weather", get_current_weather);
89+
90+
let function_args: serde_json::Value = args.parse().unwrap();
91+
92+
let location = function_args["location"].as_str().unwrap();
93+
let unit = function_args["unit"].as_str().unwrap_or("fahrenheit");
94+
let function = available_functions.get(name).unwrap();
95+
let function_response = function(location, unit); // call the function
96+
97+
let message = vec![
98+
ChatCompletionRequestMessageArgs::default()
99+
.role(Role::User)
100+
.content("What's the weather like in Boston?")
101+
.build()?,
102+
ChatCompletionRequestMessageArgs::default()
103+
.role(Role::Function)
104+
.content(function_response.to_string())
105+
.name(name.clone())
106+
.build()?,
107+
];
108+
109+
let request = CreateChatCompletionRequestArgs::default()
110+
.max_tokens(512u16)
111+
.model("gpt-3.5-turbo-0613")
112+
.messages(message)
113+
.build()?;
114+
115+
// Now stream received response from model, which essentially formats the function response
116+
let mut stream = client.chat().create_stream(request).await?;
117+
118+
let mut lock = stdout().lock();
119+
while let Some(result) = stream.next().await {
120+
match result {
121+
Ok(response) => {
122+
response.choices.iter().for_each(|chat_choice| {
123+
if let Some(ref content) = chat_choice.delta.content {
124+
write!(lock, "{}", content).unwrap();
125+
}
126+
});
127+
}
128+
Err(err) => {
129+
writeln!(lock, "error: {err}").unwrap();
130+
}
131+
}
132+
stdout().flush()?;
133+
}
134+
println!("{}", "\n");
110135
Ok(())
111136
}
112137

0 commit comments

Comments
 (0)