Skip to content

Commit c7111ff

Browse files
committed
vid
1 parent eff4c8d commit c7111ff

File tree

9 files changed

+877
-46
lines changed

9 files changed

+877
-46
lines changed

β€Žpoetry.lockβ€Ž

Lines changed: 273 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

β€Žpostgres_da_ai_agent/agents/instruments.pyβ€Ž

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ def sync_messages(self, messages: list):
2929
def make_agent_chat_file(self, team_name: str):
3030
return os.path.join(self.root_dir, f"agent_chats_{team_name}.json")
3131

32+
def make_agent_cost_file(self, team_name: str):
33+
return os.path.join(self.root_dir, f"agent_cost_{team_name}.json")
34+
3235
@property
3336
def root_dir(self):
3437
return os.path.join(BASE_DIR, self.session_id)
@@ -105,6 +108,10 @@ def get_file_path(self, fname: str):
105108
def run_sql_results_file(self):
106109
return self.get_file_path("run_sql_results.json")
107110

111+
@property
112+
def sql_query_file(self):
113+
return self.get_file_path("sql_query.sql")
114+
108115
# -------------------------- Agent Functions -------------------------- #
109116

110117
def run_sql(self, sql: str) -> str:
@@ -119,6 +126,9 @@ def run_sql(self, sql: str) -> str:
119126
with open(fname, "w") as f:
120127
f.write(results_as_json)
121128

129+
with open(self.sql_query_file, "w") as f:
130+
f.write(sql)
131+
122132
return "Successfully delivered results to json file"
123133

124134
def validate_run_sql(self):
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
import json
2+
import os
3+
import openai
4+
import time
5+
from typing import Callable, Dict, Any, List, Optional, Union, Tuple
6+
import dotenv
7+
from dataclasses import dataclass, asdict
8+
from openai.types.beta import Thread, Assistant
9+
from openai.types import FileObject
10+
from openai.types.beta.threads.thread_message import ThreadMessage
11+
from openai.types.beta.threads.run_submit_tool_outputs_params import ToolOutput
12+
from postgres_da_ai_agent.modules import llm
13+
from postgres_da_ai_agent.types import Chat, TurboTool
14+
15+
dotenv.load_dotenv()
16+
17+
18+
class Turbo4:
19+
"""
20+
Simple, chainable class for the OpenAI's GPT-4 Assistant APIs.
21+
"""
22+
23+
def __init__(self):
24+
openai.api_key = os.environ.get("OPENAI_API_KEY")
25+
self.client = openai.OpenAI()
26+
self.map_function_tools: Dict[str, TurboTool] = {}
27+
self.current_thread_id = None
28+
self.thread_messages: List[ThreadMessage] = []
29+
self.local_messages = []
30+
self.assistant_id = None
31+
self.polling_interval = (
32+
0.5 # Interval in seconds to poll the API for thread run completion
33+
)
34+
self.model = "gpt-4-1106-preview"
35+
36+
@property
37+
def chat_messages(self) -> List[Chat]:
38+
return [
39+
Chat(
40+
from_name=msg.role,
41+
to_name="assistant" if msg.role == "user" else "user",
42+
message=llm.safe_get(msg.model_dump(), "content.0.text.value"),
43+
created=msg.created_at,
44+
)
45+
for msg in self.thread_messages
46+
]
47+
48+
@property
49+
def tool_config(self):
50+
return [tool.config for tool in self.map_function_tools.values()]
51+
52+
def run_validation(self, validation_func: Callable):
53+
print(f"run_validation({validation_func.__name__})")
54+
validation_func()
55+
return self
56+
57+
def spy_on_assistant(self, output_file: str):
58+
sorted_messages = sorted(
59+
self.chat_messages, key=lambda msg: msg.created, reverse=False
60+
)
61+
messages_as_json = [asdict(msg) for msg in sorted_messages]
62+
with open(output_file, "w") as f:
63+
json.dump(messages_as_json, f, indent=2)
64+
65+
return self
66+
67+
def get_costs_and_tokens(self, output_file: str) -> Tuple[float, float]:
68+
"""
69+
Get the estimated cost and token usage for the current thread.
70+
71+
https://openai.com/pricing
72+
73+
Open questions - how to calculate retrieval and code interpreter costs?
74+
"""
75+
76+
retrival_costs = 0
77+
code_interpreter_costs = 0
78+
79+
msgs = [
80+
llm.safe_get(msg.model_dump(), "content.0.text.value")
81+
for msg in self.thread_messages
82+
]
83+
joined_msgs = " ".join(msgs)
84+
85+
msg_cost, tokens = llm.estimate_price_and_tokens(joined_msgs)
86+
87+
with open(output_file, "w") as f:
88+
json.dump(
89+
{
90+
"cost": msg_cost,
91+
"tokens": tokens,
92+
},
93+
f,
94+
indent=2,
95+
)
96+
97+
return self
98+
99+
def set_instructions(self, instructions: str):
100+
print(f"set_instructions()")
101+
if self.assistant_id is None:
102+
raise ValueError(
103+
"No assistant has been created or retrieved. Call get_or_create_assistant() first."
104+
)
105+
# Update the assistant with the new instructions
106+
updated_assistant = self.client.beta.assistants.update(
107+
assistant_id=self.assistant_id, instructions=instructions
108+
)
109+
return self
110+
111+
def get_or_create_assistant(self, name: str, model: str = "gpt-4-1106-preview"):
112+
print(f"get_or_create_assistant({name}, {model})")
113+
# Retrieve the list of existing assistants
114+
assistants: List[Assistant] = self.client.beta.assistants.list().data
115+
116+
# Check if an assistant with the given name already exists
117+
for assistant in assistants:
118+
if assistant.name == name:
119+
self.assistant_id = assistant.id
120+
121+
# update model if different
122+
if assistant.model != model:
123+
print(f"Updating assistant model from {assistant.model} to {model}")
124+
self.client.beta.assistants.update(
125+
assistant_id=self.assistant_id, model=model
126+
)
127+
break
128+
else: # If no assistant was found with the name, create a new one
129+
assistant = self.client.beta.assistants.create(model=model, name=name)
130+
self.assistant_id = assistant.id
131+
132+
self.model = model
133+
134+
return self
135+
136+
def equip_tools(
137+
self, turbo_tools: List[TurboTool], equip_on_assistant: bool = False
138+
):
139+
print(f"equip_tools({turbo_tools}, {equip_on_assistant})")
140+
if self.assistant_id is None:
141+
raise ValueError(
142+
"No assistant has been created or retrieved. Call get_or_create_assistant() first."
143+
)
144+
145+
# Update the functions dictionary with the new tools
146+
self.map_function_tools = {tool.name: tool for tool in turbo_tools}
147+
148+
if equip_on_assistant:
149+
# Update the assistant with the new list of tools, replacing any existing tools
150+
updated_assistant = self.client.beta.assistants.update(
151+
tools=self.tool_config, assistant_id=self.assistant_id
152+
)
153+
154+
return self
155+
156+
def enable_retrieval(self):
157+
print(f"enable_retrieval()")
158+
if self.assistant_id is None:
159+
raise ValueError(
160+
"No assistant has been created or retrieved. Call get_or_create_assistant() first."
161+
)
162+
163+
# Update the assistant with the new list of tools, replacing any existing tools
164+
updated_assistant = self.client.beta.assistants.update(
165+
tools=[{"type": "retrieval"}], assistant_id=self.assistant_id
166+
)
167+
168+
return self
169+
170+
def make_thread(self):
171+
print(f"make_thread()")
172+
173+
if self.assistant_id is None:
174+
raise ValueError(
175+
"No assistant has been created. Call create_assistant() first."
176+
)
177+
178+
response = self.client.beta.threads.create()
179+
self.current_thread_id = response.id
180+
self.thread_messages = []
181+
return self
182+
183+
def add_message(self, message: str, refresh_threads: bool = False):
184+
print(f"add_message({message})")
185+
self.local_messages.append(message)
186+
self.client.beta.threads.messages.create(
187+
thread_id=self.current_thread_id, content=message, role="user"
188+
)
189+
if refresh_threads:
190+
self.load_threads()
191+
return self
192+
193+
def load_threads(self):
194+
self.thread_messages = self.client.beta.threads.messages.list(
195+
thread_id=self.current_thread_id
196+
).data
197+
198+
def list_steps(self):
199+
print(f"list_steps()")
200+
steps = self.client.beta.threads.runs.steps.list(
201+
thread_id=self.current_thread_id,
202+
run_id=self.run_id,
203+
)
204+
print("steps", steps)
205+
return steps
206+
207+
def run_thread(self, toolbox: Optional[List[str]] = None):
208+
print(f"run_thread({toolbox})")
209+
if self.current_thread_id is None:
210+
raise ValueError("No thread has been created. Call make_thread() first.")
211+
if self.local_messages == []:
212+
raise ValueError("No messages have been added to the thread.")
213+
214+
if toolbox is None:
215+
tools = None
216+
else:
217+
# get tools from toolbox
218+
tools = [self.map_function_tools[tool_name].config for tool_name in toolbox]
219+
220+
# throw if tool not found
221+
if len(tools) != len(toolbox):
222+
raise ValueError(
223+
f"Tool not found in toolbox. toolbox={toolbox}, tools={tools}. Make sure all tools are equipped on the assistant."
224+
)
225+
226+
# refresh current thread
227+
self.load_threads()
228+
229+
# Start the thread running
230+
run = self.client.beta.threads.runs.create(
231+
thread_id=self.current_thread_id,
232+
assistant_id=self.assistant_id,
233+
tools=tools,
234+
)
235+
self.run_id = run.id
236+
237+
# Polling mechanism to wait for thread's run completion or required actions
238+
while True:
239+
# self.list_steps()
240+
241+
run_status = self.client.beta.threads.runs.retrieve(
242+
thread_id=self.current_thread_id, run_id=self.run_id
243+
)
244+
if run_status.status == "requires_action":
245+
tool_outputs: List[ToolOutput] = []
246+
for (
247+
tool_call
248+
) in run_status.required_action.submit_tool_outputs.tool_calls:
249+
tool_function = tool_call.function
250+
tool_name = tool_function.name
251+
252+
# Check if tool_arguments is already a dictionary, if so, proceed directly
253+
if isinstance(tool_function.arguments, dict):
254+
tool_arguments = tool_function.arguments
255+
else:
256+
# Assume the arguments are JSON string and parse them
257+
tool_arguments = json.loads(tool_function.arguments)
258+
259+
print(f"run_thread() Calling {tool_name}({tool_arguments})")
260+
261+
# Assuming arguments are passed as a dictionary
262+
function_output = self.map_function_tools[tool_name].function(
263+
**tool_arguments
264+
)
265+
266+
tool_outputs.append(
267+
ToolOutput(tool_call_id=tool_call.id, output=function_output)
268+
)
269+
270+
# Submit the tool outputs back to the API
271+
self.client.beta.threads.runs.submit_tool_outputs(
272+
thread_id=self.current_thread_id,
273+
run_id=self.run_id,
274+
tool_outputs=[to for to in tool_outputs],
275+
)
276+
elif run_status.status == "completed":
277+
self.load_threads()
278+
return self
279+
280+
time.sleep(self.polling_interval) # Wait a little before polling again

β€Žpostgres_da_ai_agent/main.pyβ€Ž

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -111,30 +111,30 @@ def main():
111111

112112
# ----------- Data Eng Team: Based on a sql table definitions and a prompt create an sql statement and execute it -------------
113113

114-
# data_eng_orchestrator = agents.build_team_orchestrator(
115-
# "data_eng",
116-
# agent_instruments,
117-
# validate_results=agent_instruments.validate_run_sql,
118-
# )
119-
120-
# data_eng_conversation_result: ConversationResult = (
121-
# data_eng_orchestrator.sequential_conversation(prompt)
122-
# )
123-
124-
# match data_eng_conversation_result:
125-
# case ConversationResult(
126-
# success=True, cost=data_eng_cost, tokens=data_eng_tokens
127-
# ):
128-
# print(
129-
# f"βœ… Orchestrator was successful. Team: {data_eng_orchestrator.name}"
130-
# )
131-
# print(
132-
# f"πŸ’°πŸ“ŠπŸ€– {data_eng_orchestrator.name} Cost: {data_eng_cost}, tokens: {data_eng_tokens}"
133-
# )
134-
# case _:
135-
# print(
136-
# f"❌ Orchestrator failed. Team: {data_eng_orchestrator.name} Failed"
137-
# )
114+
data_eng_orchestrator = agents.build_team_orchestrator(
115+
"data_eng",
116+
agent_instruments,
117+
validate_results=agent_instruments.validate_run_sql,
118+
)
119+
120+
data_eng_conversation_result: ConversationResult = (
121+
data_eng_orchestrator.sequential_conversation(prompt)
122+
)
123+
124+
match data_eng_conversation_result:
125+
case ConversationResult(
126+
success=True, cost=data_eng_cost, tokens=data_eng_tokens
127+
):
128+
print(
129+
f"βœ… Orchestrator was successful. Team: {data_eng_orchestrator.name}"
130+
)
131+
print(
132+
f"πŸ’°πŸ“ŠπŸ€– {data_eng_orchestrator.name} Cost: {data_eng_cost}, tokens: {data_eng_tokens}"
133+
)
134+
case _:
135+
print(
136+
f"❌ Orchestrator failed. Team: {data_eng_orchestrator.name} Failed"
137+
)
138138

139139
# ----------- Data Insights Team: Based on sql table definitions and a prompt generate novel insights -------------
140140

0 commit comments

Comments
Β (0)