1
1
import datetime
2
- import sqlparse
3
2
from typing import List
4
3
5
4
import orjson
6
- from sqlalchemy import and_
7
- from sqlalchemy . orm import load_only
5
+ import sqlparse
6
+ from sqlalchemy import and_ , select
8
7
9
8
from apps .chat .models .chat_model import Chat , ChatRecord , CreateChat , ChatInfo , RenameChat , ChatQuestion
10
9
from apps .datasource .models .datasource import CoreDatasource
@@ -47,26 +46,29 @@ def delete_chat(session, chart_id) -> str:
47
46
48
47
49
48
def get_chat_chart_data (session : SessionDep , chart_record_id : int ):
50
- res = session .query (ChatRecord ).options (load_only (ChatRecord .data )).get (chart_record_id )
51
- if res :
49
+ stmt = select (ChatRecord .data ).where (and_ (ChatRecord .id == chart_record_id ))
50
+ res = session .execute (stmt )
51
+ for row in res :
52
52
try :
53
- return orjson .loads (res .data )
53
+ return orjson .loads (row .data )
54
54
except Exception :
55
55
pass
56
56
return {}
57
57
58
58
59
59
def get_chat_predict_data (session : SessionDep , chart_record_id : int ):
60
- res = session .query (ChatRecord ).options (load_only (ChatRecord .predict_data )).get (chart_record_id )
61
- if res :
60
+ stmt = select (ChatRecord .predict_data ).where (and_ (ChatRecord .id == chart_record_id ))
61
+ res = session .execute (stmt )
62
+ for row in res :
62
63
try :
63
- return orjson .loads (res .predict_data )
64
+ return orjson .loads (row .predict_data )
64
65
except Exception :
65
66
pass
66
- return ''
67
+ return {}
67
68
68
69
69
- def get_chat_with_records (session : SessionDep , chart_id : int , current_user : CurrentUser , current_assistant : CurrentAssistant ) -> ChatInfo :
70
+ def get_chat_with_records (session : SessionDep , chart_id : int , current_user : CurrentUser ,
71
+ current_assistant : CurrentAssistant ) -> ChatInfo :
70
72
chat = session .get (Chat , chart_id )
71
73
if not chat :
72
74
raise Exception (f"Chat with id { chart_id } not found" )
@@ -78,7 +80,7 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr
78
80
ds = out_ds_instance .get_ds (chat .datasource )
79
81
else :
80
82
ds = session .get (CoreDatasource , chat .datasource ) if chat .datasource else None
81
-
83
+
82
84
if not ds :
83
85
chat_info .datasource_exists = False
84
86
chat_info .datasource_name = 'Datasource not exist'
@@ -87,14 +89,25 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr
87
89
chat_info .datasource_name = ds .name
88
90
chat_info .ds_type = ds .type
89
91
90
- record_list = session .query (ChatRecord ).options (
91
- load_only (ChatRecord .id , ChatRecord .chat_id , ChatRecord .create_time , ChatRecord .finish_time ,
92
- ChatRecord .question , ChatRecord .sql_answer , ChatRecord .sql , ChatRecord .data ,
92
+ stmt = select (ChatRecord .id , ChatRecord .chat_id , ChatRecord .create_time , ChatRecord .finish_time ,
93
+ ChatRecord .question , ChatRecord .sql_answer , ChatRecord .sql ,
93
94
ChatRecord .chart_answer , ChatRecord .chart , ChatRecord .analysis , ChatRecord .predict ,
94
95
ChatRecord .datasource_select_answer , ChatRecord .analysis_record_id , ChatRecord .predict_record_id ,
95
96
ChatRecord .recommended_question , ChatRecord .first_chat ,
96
- ChatRecord .predict_data , ChatRecord .finish , ChatRecord .error )).filter (
97
- and_ (Chat .create_by == current_user .id , ChatRecord .chat_id == chart_id )).order_by (ChatRecord .create_time ).all ()
97
+ ChatRecord .finish , ChatRecord .error ).where (
98
+ and_ (ChatRecord .create_by == current_user .id , ChatRecord .chat_id == chart_id )).order_by (ChatRecord .create_time )
99
+ result = session .execute (stmt ).all ()
100
+ record_list : list [ChatRecord ] = []
101
+ for row in result :
102
+ record_list .append (
103
+ ChatRecord (id = row .id , chat_id = row .chat_id , create_time = row .create_time , finish_time = row .finish_time ,
104
+ question = row .question , sql_answer = row .sql_answer , sql = row .sql ,
105
+ chart_answer = row .chart_answer , chart = row .chart ,
106
+ analysis = row .analysis , predict = row .predict ,
107
+ datasource_select_answer = row .datasource_select_answer ,
108
+ analysis_record_id = row .analysis_record_id , predict_record_id = row .predict_record_id ,
109
+ recommended_question = row .recommended_question , first_chat = row .first_chat ,
110
+ finish = row .finish , error = row .error ))
98
111
99
112
result = list (map (format_record , record_list ))
100
113
0 commit comments