@@ -12,6 +12,7 @@ use async_openai::{
 
 use futures::StreamExt;
 use serde_json::json;
+use async_openai::config::OpenAIConfig;
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn Error>> {
@@ -42,71 +43,95 @@ async fn main() -> Result<(), Box<dyn Error>> {
         .function_call("auto")
         .build()?;
 
-    // the first response from GPT is just the json response containing the function that was called
-    // and the model-generated arguments for that function (don't stream this)
-    let response = client
-        .chat()
-        .create(request)
-        .await?
-        .choices
-        .get(0)
-        .unwrap()
-        .message
-        .clone();
-
-    if let Some(function_call) = response.function_call {
-        let mut available_functions: HashMap<&str, fn(&str, &str) -> serde_json::Value> =
-            HashMap::new();
-        available_functions.insert("get_current_weather", get_current_weather);
-
-        let function_name = function_call.name;
-        let function_args: serde_json::Value = function_call.arguments.parse().unwrap();
-
-        let location = function_args["location"].as_str().unwrap();
-        let unit = "fahrenheit"; // why doesn't the model return a unit argument?
-        let function = available_functions.get(function_name.as_str()).unwrap();
-        let function_response = function(location, unit); // call the function
-
-        let message = vec![
-            ChatCompletionRequestMessageArgs::default()
-                .role(Role::User)
-                .content("What's the weather like in Boston?")
-                .build()?,
-            ChatCompletionRequestMessageArgs::default()
-                .role(Role::Function)
-                .content(function_response.to_string())
-                .name(function_name)
-                .build()?,
-        ];
-
-        let request = CreateChatCompletionRequestArgs::default()
-            .max_tokens(512u16)
-            .model("gpt-3.5-turbo-0613")
-            .messages(message)
-            .build()?;
-
-        // Now stream received response from model, which essentially formats the function response
-        let mut stream = client.chat().create_stream(request).await?;
-
-        let mut lock = stdout().lock();
-        while let Some(result) = stream.next().await {
-            match result {
-                Ok(response) => {
-                    response.choices.iter().for_each(|chat_choice| {
-                        if let Some(ref content) = chat_choice.delta.content {
-                            write!(lock, "{}", content).unwrap();
+    let mut stream = client.chat().create_stream(request).await?;
+
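+    // Accumulate the streamed function name and arguments chunk by chunk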
+    let mut fn_name = String::new();
+    let mut fn_args = String::new();
+
+    let mut lock = stdout().lock();
+    while let Some(result) = stream.next().await {
+        match result {
+            Ok(response) => {
+                for chat_choice in response.choices {
+                    if let Some(fn_call) = &chat_choice.delta.function_call {
+                        writeln!(lock, "function_call: {:?}", fn_call).unwrap();
+                        if let Some(name) = &fn_call.name {
+                            fn_name = name.clone();
                         }
-                    });
-                }
-                Err(err) => {
-                    writeln!(lock, "error: {err}").unwrap();
+                        if let Some(args) = &fn_call.arguments {
+                            fn_args.push_str(args);
+                        }
+                    }
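+                    // finish_reason "function_call" marks the end of the streamed call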
+                    if let Some(finish_reason) = &chat_choice.finish_reason {
+                        if finish_reason == "function_call" {
+                            call_fn(&client, &fn_name, &fn_args).await?;
+                        }
+                    } else if let Some(content) = &chat_choice.delta.content {
+                        write!(lock, "{}", content).unwrap();
+                    }
                 }
             }
-            stdout().flush()?;
+            Err(err) => {
+                writeln!(lock, "error: {err}").unwrap();
+            }
         }
-        println!("{}", "\n");
+        stdout().flush()?;
     }
 
+
+    Ok(())
+}
+
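+// Looks up and calls the requested local function, then streams the model's follow-up reply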
+async fn call_fn(client: &Client<OpenAIConfig>, name: &str, args: &str) -> Result<(), Box<dyn Error>> {
+    let mut available_functions: HashMap<&str, fn(&str, &str) -> serde_json::Value> =
+        HashMap::new();
+    available_functions.insert("get_current_weather", get_current_weather);
+
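+    // The streamed argument chunks concatenate into one JSON object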
+    let function_args: serde_json::Value = args.parse().unwrap();
+
+    let location = function_args["location"].as_str().unwrap();
+    let unit = function_args["unit"].as_str().unwrap_or("fahrenheit");
+    let function = available_functions.get(name).unwrap();
+    let function_response = function(location, unit); // call the function
+
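+    // Reply with the original question plus the function's JSON result as a function-role message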
+    let message = vec![
+        ChatCompletionRequestMessageArgs::default()
+            .role(Role::User)
+            .content("What's the weather like in Boston?")
+            .build()?,
+        ChatCompletionRequestMessageArgs::default()
+            .role(Role::Function)
+            .content(function_response.to_string())
+            .name(name)
+            .build()?,
+    ];
+
+    let request = CreateChatCompletionRequestArgs::default()
+        .max_tokens(512u16)
+        .model("gpt-3.5-turbo-0613")
+        .messages(message)
+        .build()?;
+
+    // Now stream the response from the model, which essentially formats the function result
+    let mut stream = client.chat().create_stream(request).await?;
+
+    let mut lock = stdout().lock();
+    while let Some(result) = stream.next().await {
+        match result {
+            Ok(response) => {
+                response.choices.iter().for_each(|chat_choice| {
+                    if let Some(ref content) = chat_choice.delta.content {
+                        write!(lock, "{}", content).unwrap();
+                    }
+                });
+            }
+            Err(err) => {
+                writeln!(lock, "error: {err}").unwrap();
+            }
+        }
+        stdout().flush()?;
+    }
+    println!("{}", "\n");
     Ok(())
 }
 
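Note: the `get_current_weather` helper registered in the dispatch map lies outside this hunk. Judging from its signature there (`fn(&str, &str) -> serde_json::Value`) and the `serde_json::json` import, it is a hard-coded stub along these lines (a sketch, not necessarily the exact code in the example):

```rust
fn get_current_weather(location: &str, unit: &str) -> serde_json::Value {
    // Hard-coded stub: a real implementation would query a weather API
    json!({
        "location": location,
        "temperature": "72",
        "unit": unit,
        "forecast": ["sunny", "windy"]
    })
}
```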