Skip to content

Commit 6b8e1c8

Browse files
committed
Added self-correcting assistant functionality to diagnose, generate new SQL, and retry failed queries.
1 parent bebab0e commit 6b8e1c8

File tree

4 files changed

+136
-16
lines changed

4 files changed

+136
-16
lines changed

api-server/api/index.py

Lines changed: 107 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
from flask import Flask, Request, Response, jsonify, request, make_response
33
import dotenv
44
from modules import db, llm, emb, instruments
5+
from modules.turbo4 import Turbo4
56

67
import os
78

89
from modules.models import TurboTool
10+
from psycopg2 import Error as PostgresError
911

1012
app = Flask(__name__)
1113

@@ -34,17 +36,97 @@ def make_cors_response():
3436
return response
3537

3638

39+
# ---------------- Self Correcting Assistant ----------------
40+
41+
42+
def self_correcting_assistant(
    db: db.PostgresManager,
    agent_instruments: instruments.AgentInstruments,
    tools: "list[TurboTool]",
    error: PostgresError,
):
    """
    Run the "SQL Self Correction" assistant after a query failed with a Postgres error.

    The flow is a three step conversation on a single assistant thread:
    diagnose the failure, generate a replacement query, then execute it
    through the first tool in ``tools``.

    Args:
        db: Database manager; it is rolled back first and then asked for the
            table definitions that are handed to the assistant.
        agent_instruments: Supplies the file paths that are read/written here
            (table definitions, failed SQL query, results, chat and cost logs)
            and the result-file validator.
        tools: Tools equipped on the assistant. Must be a non-empty sequence;
            only ``tools[0]`` is offered during the execute step.
        error: The Postgres error raised by the failed query; it is embedded
            verbatim in the diagnosis prompt.

    Returns:
        None. Results are delivered through the files managed by
        ``agent_instruments``.
    """
    # reset db - to unblock transactions
    db.roll_back()

    all_table_definitions = db.get_table_definitions_for_prompt()

    print("Loaded all table definitions")

    # ------ File prep

    file_path = agent_instruments.self_correcting_table_def_file

    # write all_table_definitions to file so it can be uploaded to the assistant
    with open(file_path, "w") as f:
        f.write(all_table_definitions)

    files_to_upload = [file_path]

    # Read the failed query with a context manager so the handle is always closed.
    with open(agent_instruments.sql_query_file) as f:
        sql_query = f.read()

    # ------ Prompts

    output_file_path = agent_instruments.run_sql_results_file

    diagnosis_prompt = f"Given the table_definitions.sql file, the following SQL_ERROR, and the SQL_QUERY, describe the most likely cause of the error. Think step by step.\n\nSQL_ERROR: {error}\n\nSQL_QUERY: {sql_query}"

    generation_prompt = (
        "Based on your diagnosis, generate a new SQL query that will run successfully."
    )

    run_sql_prompt = "Use the run_sql function to run the SQL you've just generated."

    assistant_name = "SQL Self Correction"

    turbo4_assistant = Turbo4().get_or_create_assistant(assistant_name)

    print(f"Generated Assistant: {assistant_name}")

    file_ids = turbo4_assistant.upsert_files(files_to_upload)

    print(f"Uploaded files: {file_ids}")

    print("Running Self Correction Assistant...")

    (
        turbo4_assistant.set_instructions(
            "You're an elite SQL developer. You generate the most concise and performant SQL queries. You review failed queries and generate new SQL queries to fix them."
        )
        .enable_retrieval()
        .equip_tools(tools)
        .make_thread()
        # 1/3 STEP PATTERN: diagnose
        .add_message(diagnosis_prompt, file_ids=file_ids)
        .run_thread()
        .spy_on_assistant(agent_instruments.make_agent_chat_file(assistant_name))
        # 2/3 STEP PATTERN: generate
        .add_message(generation_prompt)
        .run_thread()
        .spy_on_assistant(agent_instruments.make_agent_chat_file(assistant_name))
        # 3/3 STEP PATTERN: execute (only the first tool is made available)
        .add_message(run_sql_prompt)
        .run_thread(toolbox=[tools[0].name])
        .spy_on_assistant(agent_instruments.make_agent_chat_file(assistant_name))
        # clean up, logging, reporting, cost
        .run_validation(agent_instruments.validate_file_exists(output_file_path))
        .spy_on_assistant(agent_instruments.make_agent_chat_file(assistant_name))
        .get_costs_and_tokens(agent_instruments.make_agent_cost_file(assistant_name))
    )
117+
118+
37119
# ---------------- Primary Endpoint ----------------
38120

39121

40122
@app.route("/prompt", methods=["POST", "OPTIONS"])
41123
def prompt():
42-
response = make_cors_response()
43-
44124
# Set CORS headers for the main request
125+
response = make_cors_response()
45126
if request.method == "OPTIONS":
46127
return response
47128

129+
# Get access to db, state, and functions
48130
with instruments.PostgresAgentInstruments(DB_URL, "prompt-endpoint") as (
49131
agent_instruments,
50132
db,
@@ -59,7 +141,10 @@ def prompt():
59141
)
60142

61143
if len(similar_tables) == 0:
62-
return jsonify({"error": "No similar tables found."})
144+
print(f"No similar tables found for prompt: {base_prompt}")
145+
response.status_code = 400
146+
response.data = "No similar tables found."
147+
return response
63148

64149
print("similar_tables", similar_tables)
65150

@@ -73,7 +158,7 @@ def prompt():
73158
similar_tables,
74159
)
75160

76-
# ---------------- Run Data Team - Generate SQL & Results ----------------
161+
# ---------------- Run 2 Agent Team - Generate SQL & Results ----------------
77162

78163
tools = [
79164
TurboTool("run_sql", llm.run_sql_tool_config, agent_instruments.run_sql),
@@ -84,14 +169,24 @@ def prompt():
84169
model="gpt-4-1106-preview",
85170
instructions="You're an elite SQL developer. You generate the most concise and performant SQL queries.",
86171
)
87-
llm.prompt_func(
88-
"Use the run_sql function to run the SQL you've just generated: "
89-
+ sql_response,
90-
model="gpt-4-1106-preview",
91-
instructions="You're an elite SQL developer. You generate the most concise and performant SQL queries.",
92-
turbo_tools=tools,
93-
)
94-
agent_instruments.validate_run_sql()
172+
try:
173+
llm.prompt_func(
174+
"Use the run_sql function to run the SQL you've just generated: "
175+
+ sql_response,
176+
model="gpt-4-1106-preview",
177+
instructions="You're an elite SQL developer. You generate the most concise and performant SQL queries.",
178+
turbo_tools=tools,
179+
)
180+
agent_instruments.validate_run_sql()
181+
except PostgresError as e:
182+
print(
183+
f"Received PostgresError -> Running Self Correction Team To Resolve: {e}"
184+
)
185+
186+
# ---------------- Run Self Correction Team - Diagnosis, Generate New SQL, Retry ----------------
187+
self_correcting_assistant(db, agent_instruments, tools, e)
188+
189+
print(f"Self Correction Team Complete.")
95190

96191
# ---------------- Read result files and respond ----------------
97192

api-server/api/modules/db.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,6 @@ def get_related_tables(self, table_list, n=2):
162162
related_tables_list = list(set(related_tables_list))
163163

164164
return related_tables_list
165+
166+
def roll_back(self):
    """Call ``rollback()`` on this manager's connection (``self.conn``)."""
    connection = self.conn
    connection.rollback()

api-server/api/modules/instruments.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
from modules.db import PostgresManager
23
from modules import file
34
import os
@@ -112,12 +113,20 @@ def run_sql_results_file(self):
112113
def sql_query_file(self):
113114
return self.get_file_path("sql_query.sql")
114115

116+
@property
def self_correcting_table_def_file(self):
    """Path for ``table_definitions.sql`` as resolved by ``self.get_file_path``."""
    table_def_path = self.get_file_path("table_definitions.sql")
    return table_def_path
119+
115120
# -------------------------- Agent Functions -------------------------- #
116121

117122
def run_sql(self, sql: str) -> str:
118123
"""
119124
Run a SQL query against the postgres database
120125
"""
126+
127+
with open(self.sql_query_file, "w") as f:
128+
f.write(sql)
129+
121130
results_as_json = self.db.run_sql(sql)
122131

123132
fname = self.run_sql_results_file
@@ -126,9 +135,6 @@ def run_sql(self, sql: str) -> str:
126135
with open(fname, "w") as f:
127136
f.write(results_as_json)
128137

129-
with open(self.sql_query_file, "w") as f:
130-
f.write(sql)
131-
132138
return "Successfully delivered results to json file"
133139

134140
def validate_run_sql(self):
@@ -171,3 +177,10 @@ def validate_innovation_files(self):
171177
return False, f"File {fname} is empty"
172178

173179
return True, ""
180+
181+
def validate_file_exists(self, file: str):
    """
    Build a zero-argument validator for the existence of ``file``.

    The check is deferred: nothing is inspected until the returned callable
    is invoked. When invoked, it returns None if the path exists and raises
    ``Exception`` otherwise.
    """

    def file_exists():
        # Guard clause: an existing path means the validation passes.
        if os.path.exists(file):
            return
        raise Exception(f"File {file} does not exist")

    return file_exists

api-server/api/modules/models.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from dataclasses import dataclass
1+
from dataclasses import dataclass, field
2+
import time
23
from typing import Callable
34

45

@@ -7,3 +8,11 @@ class TurboTool:
78
name: str
89
config: dict
910
function: Callable
11+
12+
13+
@dataclass
class Chat:
    """Record of one message, with sender and recipient names and a creation time."""

    from_name: str
    to_name: str
    message: str
    # Defaults to time.time() at construction, which is a float number of
    # seconds since the epoch, so the field is annotated as float.
    created: float = field(default_factory=time.time)

0 commit comments

Comments
 (0)