Merged

Chatbot #1716

11 changes: 9 additions & 2 deletions Cargo.lock
@@ -1100,6 +1100,12 @@ dependencies = [
"winapi 0.3.9",
]

[[package]]
name = "dotenv"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f"

[[package]]
name = "downcast"
version = "0.11.0"
@@ -2574,12 +2580,13 @@ dependencies = [

 [[package]]
 name = "integritee-cli"
-version = "0.16.6"
+version = "0.16.7"
 dependencies = [
  "array-bytes 6.1.0",
  "base58",
  "chrono 0.4.26",
  "clap 3.2.25",
+ "dotenv",
  "enclave-bridge-primitives",
  "env_logger 0.9.3",
  "hdrhistogram",
@@ -2632,7 +2639,7 @@ dependencies = [

 [[package]]
 name = "integritee-service"
-version = "0.16.6"
+version = "0.16.7"
 dependencies = [
  "anyhow",
  "async-trait",
3 changes: 2 additions & 1 deletion cli/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "integritee-cli"
-version = "0.16.6"
+version = "0.16.7"
 authors = ["Integritee AG <[email protected]>"]
 edition = "2021"

@@ -10,6 +10,7 @@ base58 = "0.2"
chrono = "*"
clap = { version = "3.1.6", features = ["derive"] }
codec = { version = "3.0.0", package = "parity-scale-codec", features = ["derive"] }
dotenv = "0.15"
env_logger = "0.9"
hdrhistogram = "7.5.0"
hex = "0.4.2"
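Note: dotenv is typically used to load secrets from a local .env file into the process environment before they are read. A minimal sketch of how the CLI might obtain the API key under that assumption (the variable name OPENAI_API_KEY is hypothetical, not taken from this diff):

// Load a local .env file into the environment, if one exists.
dotenv::dotenv().ok();
// Hypothetical variable name; this diff does not show which key is read.
let api_key = std::env::var("OPENAI_API_KEY").expect("API key must be set");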
2 changes: 2 additions & 0 deletions cli/src/lib.rs
@@ -35,6 +35,8 @@ mod benchmark;
 mod command_utils;
 #[cfg(feature = "evm")]
 mod evm;
+mod llm_handler;
+mod notes_handler;
 #[cfg(feature = "teeracle")]
 mod oracle;
 mod trusted_assets;
176 changes: 176 additions & 0 deletions cli/src/llm_handler.rs
@@ -0,0 +1,176 @@
/*
Copyright 2021 Integritee AG

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

*/

use codec::Decode;
use ita_stf::TrustedCall;
use itp_types::{AccountId, Moment};
use log::{debug, trace, warn};
use pallet_notes::{TimestampedTrustedNote, TrustedNote};
use prometheus::register_counter;
use reqwest::Client;
use serde::{Deserialize, Serialize};

// ChatGPT API types
#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: Vec<Message<'a>>,
    max_tokens: u16,
}

#[derive(Debug, Serialize)]
struct Message<'a> {
    role: &'a str,
    content: &'a str,
}

#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
    usage: Option<Usage>,
}

#[derive(Deserialize)]
struct Choice {
    message: MessageContent,
}

#[derive(Deserialize)]
struct MessageContent {
    content: String,
}

#[derive(Deserialize)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

pub struct LLMHandler {
    api_key: String,
    metrics_prompt_tokens_counter: prometheus::Counter,
    metrics_completion_tokens_counter: prometheus::Counter,
}

impl LLMHandler {
    pub fn new(api_key: String) -> Self {
        let metrics_prompt_tokens_counter =
            register_counter!("llm_prompt_tokens_counter", "Number of used prompt tokens")
                .unwrap();
        let metrics_completion_tokens_counter =
            register_counter!("llm_completion_tokens_counter", "Number of used completion tokens")
                .unwrap();
        LLMHandler { api_key, metrics_prompt_tokens_counter, metrics_completion_tokens_counter }
    }

    pub async fn process_ai_prompt(
        &self,
        prompt: String,
        system_briefing: String,
        model: String,
        bot_account: &AccountId,
        history: Vec<TimestampedTrustedNote<Moment>>,
    ) -> String {
        let mut messages: Vec<Message> =
            vec![Message { role: "system", content: system_briefing.as_str() }];
        // Replay the conversation history: notes sent by the bot account become
        // "assistant" messages, all other notes become "user" messages.
        history.iter().for_each(|note| {
            if let TrustedNote::SuccessfulTrustedCall(ref tc) = note.note {
                if let Ok(TrustedCall::send_note(from, _to, msg)) =
                    TrustedCall::decode(&mut tc.as_slice())
                {
                    let msg_str = String::from_utf8(msg).unwrap_or_else(|_| {
                        warn!("Failed to decode message as UTF-8, using empty string");
                        String::new()
                    });
                    // Box::leak yields the &'static str that Message expects; this
                    // deliberately leaks each history entry for the process lifetime.
                    if *bot_account == from {
                        messages.push(Message {
                            role: "assistant",
                            content: Box::leak(msg_str.into_boxed_str()),
                        });
                    } else {
                        messages.push(Message {
                            role: "user",
                            content: Box::leak(msg_str.into_boxed_str()),
                        });
                    }
                } else {
                    warn!("Failed to decode TrustedCall::send_note from note: {:?}", note);
                }
            }
        });
        messages.push(Message { role: "user", content: prompt.as_str() });
        trace!("Sending prompt to LLM: {:?}", messages);
        let request_body = ChatRequest {
            model: model.as_str(),
            messages,
            max_tokens: 70, // roughly 140 characters
        };
        let client = Client::new();
        let mut attempts = 0;
        // Retry up to three times on transport errors; a non-success HTTP status
        // aborts immediately.
        let response = loop {
            match client
                .post("https://api.openai.com/v1/chat/completions")
                .bearer_auth(self.api_key.clone())
                .json(&request_body)
                .send()
                .await
            {
                Ok(resp) =>
                    if resp.status().is_success() {
                        break resp
                    } else {
                        warn!("Received non-success status code: {}", resp.status());
                        return String::from("Error: Non-success status code received from LLM API")
                    },
                Err(e) => {
                    attempts += 1;
                    warn!("Failed to send request to LLM API (attempt {}): {:?}", attempts, e);
                    if attempts >= 3 {
                        return String::from(
                            "Error: Failed to send request to LLM API after 3 attempts",
                        )
                    }
                },
            }
        };

        debug!("Got response from LLM: {:?}", response);

        let json: ChatResponse = match response.json().await {
            Ok(parsed_json) => parsed_json,
            Err(e) => {
                warn!("Failed to parse LLM response JSON: {:?}", e);
                return String::from("Error: Failed to parse LLM response JSON")
            },
        };
        if let Some(usage) = json.usage {
            debug!(
                "Token usage - Prompt tokens: {}, Completion tokens: {}, Total tokens: {}",
                usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
            );
            self.metrics_prompt_tokens_counter.inc_by(usage.prompt_tokens.into());
            self.metrics_completion_tokens_counter.inc_by(usage.completion_tokens.into());
        }

        // Crop the reply to at most 200 bytes; from_utf8_lossy tolerates a UTF-8
        // sequence split at the cut. Guard against an empty choices array rather
        // than indexing into it, which would panic.
        match json.choices.first() {
            Some(choice) => {
                let content = choice.message.content.trim();
                let cropped = &content.as_bytes()[..std::cmp::min(200, content.len())];
                String::from_utf8_lossy(cropped).to_string()
            },
            None => {
                warn!("LLM response contained no choices");
                String::from("Error: LLM response contained no choices")
            },
        }
    }
}
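For context, a minimal sketch of how the new handler might be driven (only the process_ai_prompt signature comes from this diff; the runtime, account, briefing, model name, and history values are placeholders):

// Hypothetical call site, assuming a tokio runtime and values prepared elsewhere.
let handler = LLMHandler::new(api_key);
let reply = handler
    .process_ai_prompt(
        "Hello!".to_string(),                       // user prompt
        "You are a concise assistant.".to_string(), // assumed system briefing
        "gpt-3.5-turbo".to_string(),                // assumed model name
        &bot_account,                               // AccountId of the bot
        history,                                    // notes fetched elsewhere
    )
    .await;
println!("{}", reply);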
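Aside: register_counter! registers both counters with the prometheus crate's default global registry, so any exporter gathering from it will pick them up. A minimal sketch of rendering them with the standard text encoder (the exporter wiring is an assumption; this diff does not show how the service exposes metrics):

use prometheus::{Encoder, TextEncoder};

// Gather all metrics from the default registry and render the text format.
let mut buffer = Vec::new();
TextEncoder::new()
    .encode(&prometheus::gather(), &mut buffer)
    .expect("text encoding of gathered metrics should not fail");
println!("{}", String::from_utf8(buffer).unwrap());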