diff --git a/.github/workflows/package_and_push.yml b/.github/workflows/package_and_push.yml
index 3ec89e7c..f946d3c4 100644
--- a/.github/workflows/package_and_push.yml
+++ b/.github/workflows/package_and_push.yml
@@ -9,6 +9,11 @@ on:
description: 'Image Tag'
default: 'v0.9.5'
required: true
+ packageEE:
+ description: 'Whether to also package the Enterprise Edition'
+ default: false
+ required: true
+ type: boolean
jobs:
package-and-push-to-aliyun-oss:
@@ -31,6 +36,7 @@ jobs:
ALIYUN_OSS_BUCKET_ENDPOINT: ${{ secrets.ALIYUN_OSS_BUCKET_ENDPOINT }}
ALIYUN_OSS_ACCESS_KEY: ${{ secrets.ALIYUN_OSS_ACCESS_KEY }}
ALIYUN_OSS_ACCESS_SECRET: ${{ secrets.ALIYUN_OSS_ACCESS_SECRET }}
+ PACKAGE_EE: ${{ github.event.inputs.packageEE }}
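+ # Note: workflow_dispatch boolean inputs surface as the strings 'true'/'false'
+ # through github.event.inputs, hence the string comparison in the script below.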
run: |
DOCKER_IMAGE=${ALIYUN_REGISTRY_HOST}/dataease/sqlbot
cd installer
@@ -76,12 +82,26 @@ jobs:
--exclude .git \
--exclude images \
--exclude docker \
- -czvf $package_online .
+ -czvf $package_offline .
+
+ if [[ ${PACKAGE_EE} == 'true' ]]; then
+ package_offline_ee="sqlbot-offline-installer-${TAG_NAME}${platform}-ee.tar.gz"
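+ # Unlike the community offline package, the EE archive keeps images/ and docker/;
+ # it excludes only .git and the other generated tarballs.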
+ touch $package_offline_ee
+ tar --transform "s/^\./sqlbot-offline-installer-${TAG_NAME}${platform}-ee/" \
+ --exclude $package_online \
+ --exclude $package_offline \
+ --exclude $package_offline_ee \
+ --exclude .git \
+ -czvf $package_offline_ee .
+
+ ossutil cp -rf ${package_offline_ee} oss://$ALIYUN_OSS_BUCKET/sqlbot/${package_offline_ee} --access-key-id=$ALIYUN_OSS_ACCESS_KEY --access-key-secret=$ALIYUN_OSS_ACCESS_SECRET --endpoint=${ALIYUN_OSS_BUCKET_ENDPOINT}
+ fi
#Sync files to OSS
- ossutil cp -rf ${package_offline} oss://$ALIYUN_OSS_BUCKET/sqlbot/${package_offline} --access-key-id=$ALIYUN_OSS_ACCESS_KEY --access-key-secret=$ALIYUN_OSS_ACCESS_SECRET --endpoint=${ALIYUN_OSS_BUCKET_ENDPOINT}
+ ossutil cp -rf ${package_offline} oss://$ALIYUN_OSS_BUCKET/sqlbot/${package_offline} --access-key-id=$ALIYUN_OSS_ACCESS_KEY --access-key-secret=$ALIYUN_OSS_ACCESS_SECRET --endpoint=${ALIYUN_OSS_BUCKET_ENDPOINT}
ossutil cp -rf ${package_online} oss://$ALIYUN_OSS_BUCKET/sqlbot/${package_online} --access-key-id=$ALIYUN_OSS_ACCESS_KEY --access-key-secret=$ALIYUN_OSS_ACCESS_SECRET --endpoint=${ALIYUN_OSS_BUCKET_ENDPOINT}
+
diff --git a/README.md b/README.md
index 98e00c88..d3622ab9 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@ SQLBot is an intelligent data Q&A system based on large language models and RAG. SQLBot's advantages
## How It Works
-
+
## Quick Start
@@ -31,8 +31,9 @@ docker run -d \
-p 8000:8000 \
-p 8001:8001 \
-v ./data/sqlbot/excel:/opt/sqlbot/data/excel \
+ -v ./data/sqlbot/file:/opt/sqlbot/data/file \
-v ./data/sqlbot/images:/opt/sqlbot/images \
- -v ./data/sqlbot/logs:/opt/sqlbot/logs \
+ -v ./data/sqlbot/logs:/opt/sqlbot/app/logs \
-v ./data/postgresql:/var/lib/postgresql/data \
--privileged=true \
dataease/sqlbot
@@ -70,6 +71,7 @@ docker run -d \
- [1Panel](https://github.com/1panel-dev/1panel/) - A modern, open-source Linux server operations and management panel
- [MaxKB](https://github.com/1panel-dev/MaxKB/) - A powerful, easy-to-use enterprise-grade AI agent platform
- [JumpServer](https://github.com/jumpserver/jumpserver/) - A widely popular open-source bastion host
+- [Cordys CRM](https://github.com/1Panel-dev/CordysCRM) - A next-generation open-source AI CRM system
- [Halo](https://github.com/halo-dev/halo/) - A powerful, easy-to-use open-source website-building tool
- [MeterSphere](https://github.com/metersphere/metersphere/) - A next-generation open-source continuous testing tool
diff --git a/backend/alembic/env.py b/backend/alembic/env.py
index 3f72963a..79dfbdc6 100755
--- a/backend/alembic/env.py
+++ b/backend/alembic/env.py
@@ -26,8 +26,10 @@
# from apps.settings.models.setting_models import SQLModel
# from apps.chat.models.chat_model import SQLModel
from apps.terminology.models.terminology_model import SQLModel
+# from apps.data_training.models.data_training_model import SQLModel
# from apps.dashboard.models.dashboard_model import SQLModel
from common.core.config import settings # noqa
+# from apps.datasource.models.datasource import SQLModel
target_metadata = SQLModel.metadata
diff --git a/backend/alembic/versions/042_data_training.py b/backend/alembic/versions/042_data_training.py
new file mode 100644
index 00000000..ec44c89f
--- /dev/null
+++ b/backend/alembic/versions/042_data_training.py
@@ -0,0 +1,41 @@
+"""042_data_training
+
+Revision ID: a487d9c69341
+Revises: c4c3c36b720d
+Create Date: 2025-09-15 15:41:43.332771
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import sqlmodel.sql.sqltypes
+from sqlalchemy.dialects import postgresql
+import pgvector.sqlalchemy
+
+# revision identifiers, used by Alembic.
+revision = 'a487d9c69341'
+down_revision = 'c4c3c36b720d'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('data_training',
+ sa.Column('id', sa.BigInteger(), sa.Identity(always=True), nullable=False),
+ sa.Column('oid', sa.BigInteger(), nullable=True),
+ sa.Column('datasource', sa.BigInteger(), nullable=True),
+ sa.Column('create_time', sa.DateTime(), nullable=True),
+ sa.Column('question', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True),
+ sa.Column('description', sa.Text(), nullable=True),
+ sa.Column('embedding', pgvector.sqlalchemy.vector.VECTOR(), nullable=True),
+ sa.PrimaryKeyConstraint('id')
+ )
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+
+ op.drop_table('data_training')
+ # ### end Alembic commands ###
diff --git a/backend/alembic/versions/043_modify_ds_id_type.py b/backend/alembic/versions/043_modify_ds_id_type.py
new file mode 100644
index 00000000..4ef303a0
--- /dev/null
+++ b/backend/alembic/versions/043_modify_ds_id_type.py
@@ -0,0 +1,55 @@
+"""043_modify_ds_id_type
+
+Revision ID: dac062c1f7b1
+Revises: a487d9c69341
+Create Date: 2025-09-22 17:20:44.465735
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = 'dac062c1f7b1'
+down_revision = 'a487d9c69341'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.alter_column('core_datasource', 'id',
+ existing_type=sa.INTEGER(),
+ type_=sa.BigInteger(),
+ existing_nullable=False,
+ autoincrement=True)
+ op.alter_column('core_field', 'id',
+ existing_type=sa.INTEGER(),
+ type_=sa.BigInteger(),
+ existing_nullable=False,
+ autoincrement=True)
+ op.alter_column('core_table', 'id',
+ existing_type=sa.INTEGER(),
+ type_=sa.BigInteger(),
+ existing_nullable=False,
+ autoincrement=True)
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.alter_column('core_table', 'id',
+ existing_type=sa.BigInteger(),
+ type_=sa.INTEGER(),
+ existing_nullable=False,
+ autoincrement=True)
+ op.alter_column('core_field', 'id',
+ existing_type=sa.BigInteger(),
+ type_=sa.INTEGER(),
+ existing_nullable=False,
+ autoincrement=True)
+ op.alter_column('core_datasource', 'id',
+ existing_type=sa.BigInteger(),
+ type_=sa.INTEGER(),
+ existing_nullable=False,
+ autoincrement=True)
+ # ### end Alembic commands ###
diff --git a/backend/alembic/versions/044_table_relation.py b/backend/alembic/versions/044_table_relation.py
new file mode 100644
index 00000000..9f655ed4
--- /dev/null
+++ b/backend/alembic/versions/044_table_relation.py
@@ -0,0 +1,29 @@
+"""044_table_relation
+
+Revision ID: 455b8ce69e80
+Revises: dac062c1f7b1
+Create Date: 2025-09-24 13:34:08.205659
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import sqlmodel.sql.sqltypes
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '455b8ce69e80'
+down_revision = 'dac062c1f7b1'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('core_datasource', sa.Column('table_relation', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column('core_datasource', 'table_relation')
+ # ### end Alembic commands ###
diff --git a/backend/alembic/versions/045_modify_terminolog.py b/backend/alembic/versions/045_modify_terminolog.py
new file mode 100644
index 00000000..a452bb6c
--- /dev/null
+++ b/backend/alembic/versions/045_modify_terminolog.py
@@ -0,0 +1,31 @@
+"""045_modify_terminolog
+
+Revision ID: 45e7e52bf2b8
+Revises: 455b8ce69e80
+Create Date: 2025-09-25 14:49:24.521795
+
+"""
+from alembic import op
+import sqlalchemy as sa
+import sqlmodel.sql.sqltypes
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '45e7e52bf2b8'
+down_revision = '455b8ce69e80'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('terminology', sa.Column('specific_ds', sa.Boolean(), nullable=True))
+ op.add_column('terminology', sa.Column('datasource_ids', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column('terminology', 'datasource_ids')
+ op.drop_column('terminology', 'specific_ds')
+ # ### end Alembic commands ###
diff --git a/backend/apps/api.py b/backend/apps/api.py
index 14b07cea..8b836c0d 100644
--- a/backend/apps/api.py
+++ b/backend/apps/api.py
@@ -1,11 +1,12 @@
from fastapi import APIRouter
-from apps.terminology.api import terminology
from apps.chat.api import chat
from apps.dashboard.api import dashboard_api
-from apps.datasource.api import datasource
-from apps.system.api import login, user, aimodel, workspace, assistant
+from apps.data_training.api import data_training
+from apps.datasource.api import datasource, table_relation
from apps.mcp import mcp
+from apps.system.api import login, user, aimodel, workspace, assistant
+from apps.terminology.api import terminology
api_router = APIRouter()
api_router.include_router(login.router)
@@ -14,9 +15,9 @@
api_router.include_router(assistant.router)
api_router.include_router(aimodel.router)
api_router.include_router(terminology.router)
+api_router.include_router(data_training.router)
api_router.include_router(datasource.router)
api_router.include_router(chat.router)
api_router.include_router(dashboard_api.router)
api_router.include_router(mcp.router)
-
-
+api_router.include_router(table_relation.router)
diff --git a/backend/apps/chat/api/chat.py b/backend/apps/chat/api/chat.py
index fa85d646..a28c03d7 100644
--- a/backend/apps/chat/api/chat.py
+++ b/backend/apps/chat/api/chat.py
@@ -13,7 +13,7 @@
delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, ExcelData
from apps.chat.task.llm import LLMService
-from common.core.deps import CurrentAssistant, SessionDep, CurrentUser
+from common.core.deps import CurrentAssistant, SessionDep, CurrentUser, Trans
router = APIRouter(tags=["Data Q&A"], prefix="/chat")
@@ -105,14 +105,15 @@ async def start_chat(session: SessionDep, current_user: CurrentUser):
@router.post("/recommend_questions/{chat_record_id}")
async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int,
current_assistant: CurrentAssistant):
+ def _return_empty():
+ yield 'data:' + orjson.dumps({'content': '[]', 'type': 'recommended_question'}).decode() + '\n\n'
+
try:
record = get_chat_record_by_id(session, chat_record_id)
if not record:
- raise HTTPException(
- status_code=400,
- detail=f"Chat record with id {chat_record_id} not found"
- )
+ return StreamingResponse(_return_empty(), media_type="text/event-stream")
+
request_question = ChatQuestion(chat_id=record.chat_id, question=record.question if record.question else '')
llm_service = await LLMService.create(current_user, request_question, current_assistant, True)
@@ -120,10 +121,11 @@ async def recommend_questions(session: SessionDep, current_user: CurrentUser, ch
llm_service.run_recommend_questions_task_async()
except Exception as e:
traceback.print_exc()
- raise HTTPException(
- status_code=500,
- detail=str(e)
- )
+
+ def _err(_e: Exception):
+ yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'
+
+ return StreamingResponse(_err(e), media_type="text/event-stream")
return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
@@ -143,7 +145,7 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
"""
try:
- llm_service = await LLMService.create(current_user, request_question, current_assistant)
+ llm_service = await LLMService.create(current_user, request_question, current_assistant, embedding=True)
llm_service.init_record()
llm_service.run_task_async()
except Exception as e:
@@ -199,10 +201,16 @@ def _err(_e: Exception):
@router.post("/excel/export")
-async def export_excel(excel_data: ExcelData):
+async def export_excel(excel_data: ExcelData, trans: Trans):
def inner():
_fields_list = []
data = []
+ if not excel_data.data:
+ raise HTTPException(
+ status_code=500,
+ detail=trans("i18n_excel_export.data_is_empty")
+ )
+
for _data in excel_data.data:
_row = []
for field in excel_data.axis:
diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py
index eb418251..6cc30b83 100644
--- a/backend/apps/chat/models/chat_model.py
+++ b/backend/apps/chat/models/chat_model.py
@@ -40,6 +40,10 @@ class OperationEnum(Enum):
CHOOSE_DATASOURCE = '6'
GENERATE_DYNAMIC_SQL = '7'
+class ChatFinishStep(Enum):
+ GENERATE_SQL = 1
+ QUERY_DATA = 2
+ GENERATE_CHART = 3
# TODO choose table / check connection / generate description
@@ -136,7 +140,7 @@ class CreateChat(BaseModel):
id: int = None
question: str = None
datasource: int = None
- origin: Optional[int] = 0
+ origin: Optional[int] = 0 # 0 = web page, 1 = MCP, 2 = assistant
class RenameChat(BaseModel):
@@ -172,11 +176,13 @@ class AiModelQuestion(BaseModel):
filter: str = []
sub_query: Optional[list[dict]] = None
terminologies: str = ""
+ data_training: str = ""
error_msg: str = ""
def sql_sys_question(self):
return get_sql_template()['system'].format(engine=self.engine, schema=self.db_schema, question=self.question,
- lang=self.lang, terminologies=self.terminologies)
+ lang=self.lang, terminologies=self.terminologies,
+ data_training=self.data_training)
def sql_user_question(self, current_time: str):
return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question,
@@ -244,6 +250,7 @@ class McpQuestion(BaseModel):
question: str = Body(description='User question')
chat_id: int = Body(description='Chat session ID')
token: str = Body(description='token')
+ stream: Optional[bool] = Body(description='Whether to stream the output; defaults to true. Set to false to return a JSON object instead', default=True)
class AxisObj(BaseModel):
@@ -256,3 +263,10 @@ class ExcelData(BaseModel):
axis: list[AxisObj] = []
data: list[dict] = []
name: str = 'Excel'
+
+
+class McpAssistant(BaseModel):
+ question: str = Body(description='User question')
+ url: str = Body(description='Third-party data API endpoint')
+ authorization: str = Body(description='Third-party API credential')
+ stream: Optional[bool] = Body(description='Whether to stream the output; defaults to true. Set to false to return a JSON object instead', default=True)
diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py
index 8cc1f470..a30a741d 100644
--- a/backend/apps/chat/task/llm.py
+++ b/backend/apps/chat/task/llm.py
@@ -6,7 +6,7 @@
import warnings
from concurrent.futures import ThreadPoolExecutor, Future
from datetime import datetime
-from typing import Any, List, Optional, Union, Dict
+from typing import Any, List, Optional, Union, Dict, Iterator
import numpy as np
import orjson
@@ -16,7 +16,7 @@
from langchain.chat_models.base import BaseChatModel
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, BaseMessageChunk
-from sqlalchemy import select
+from sqlalchemy import and_, select
from sqlalchemy.orm import sessionmaker
from sqlmodel import Session
@@ -28,9 +28,12 @@
get_old_questions, save_analysis_predict_record, rename_chat, get_chart_config, \
get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log, \
get_last_execute_sql_error
-from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum
+from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \
+ ChatFinishStep
+from apps.data_training.curd.data_training import get_training_template
from apps.datasource.crud.datasource import get_table_schema
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
+from apps.datasource.embedding.ds_embedding import get_ds_embedding
from apps.datasource.models.datasource import CoreDatasource
from apps.db.db import exec_sql, get_version, check_connection
from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds
@@ -82,7 +85,7 @@ class LLMService:
def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion,
current_assistant: Optional[CurrentAssistant] = None, no_reasoning: bool = False,
- config: LLMConfig = None):
+ embedding: bool = False, config: LLMConfig = None):
self.chunk_list = []
# engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
# session_maker = sessionmaker(bind=engine)
@@ -111,7 +114,8 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion,
if not ds:
raise SingleMessageError("No available datasource configuration found")
chat_question.engine = (ds.type_name if ds.type != 'excel' else 'PostgreSQL') + get_version(ds)
- chat_question.db_schema = get_table_schema(session=self.session, current_user=current_user, ds=ds)
+ chat_question.db_schema = get_table_schema(session=self.session, current_user=current_user, ds=ds,
+ question=chat_question.question, embedding=embedding)
self.generate_sql_logs = list_generate_sql_logs(session=self.session, chart_id=chat_id)
self.generate_chart_logs = list_generate_chart_logs(session=self.session, chart_id=chat_id)
@@ -146,8 +150,6 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion,
else:
self.chat_question.error_msg = ''
- self.init_messages()
-
@classmethod
async def create(cls, *args, **kwargs):
config: LLMConfig = await get_default_config()
@@ -239,8 +241,9 @@ def generate_analysis(self):
self.chat_question.data = orjson.dumps(data.get('data')).decode()
analysis_msg: List[Union[BaseMessage, dict[str, Any]]] = []
+ ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question,
- self.current_user.oid)
+ self.current_user.oid, ds_id)
analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question()))
analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question()))
@@ -257,22 +260,14 @@ def generate_analysis(self):
in analysis_msg])
full_thinking_text = ''
full_analysis_text = ''
- res = self.llm.stream(analysis_msg)
token_usage = {}
+ res = process_stream(self.llm.stream(analysis_msg), token_usage)
for chunk in res:
- SQLBotLogUtil.info(chunk)
- reasoning_content_chunk = ''
- if 'reasoning_content' in chunk.additional_kwargs:
- reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
- # else:
- # reasoning_content_chunk = chunk.get('reasoning_content')
- if reasoning_content_chunk is None:
- reasoning_content_chunk = ''
- full_thinking_text += reasoning_content_chunk
-
- full_analysis_text += chunk.content
- yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
- get_token_usage(chunk, token_usage)
+ if chunk.get('content'):
+ full_analysis_text += chunk.get('content')
+ if chunk.get('reasoning_content'):
+ full_thinking_text += chunk.get('reasoning_content')
+ yield chunk
analysis_msg.append(AIMessage(full_analysis_text))
@@ -309,22 +304,14 @@ def generate_predict(self):
in predict_msg])
full_thinking_text = ''
full_predict_text = ''
- res = self.llm.stream(predict_msg)
token_usage = {}
+ res = process_stream(self.llm.stream(predict_msg), token_usage)
for chunk in res:
- SQLBotLogUtil.info(chunk)
- reasoning_content_chunk = ''
- if 'reasoning_content' in chunk.additional_kwargs:
- reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
- # else:
- # reasoning_content_chunk = chunk.get('reasoning_content')
- if reasoning_content_chunk is None:
- reasoning_content_chunk = ''
- full_thinking_text += reasoning_content_chunk
-
- full_predict_text += chunk.content
- yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
- get_token_usage(chunk, token_usage)
+ if chunk.get('content'):
+ full_predict_text += chunk.get('content')
+ if chunk.get('reasoning_content'):
+ full_thinking_text += chunk.get('reasoning_content')
+ yield chunk
predict_msg.append(AIMessage(full_predict_text))
self.record = save_predict_answer(session=self.session, record_id=self.record.id,
@@ -345,7 +332,9 @@ def generate_recommend_questions_task(self):
if self.ds and not self.chat_question.db_schema:
self.chat_question.db_schema = self.out_ds_instance.get_db_schema(
self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session,
- current_user=self.current_user, ds=self.ds)
+ current_user=self.current_user, ds=self.ds,
+ question=self.chat_question.question,
+ embedding=False)
guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question()))
@@ -367,21 +356,13 @@ def generate_recommend_questions_task(self):
full_thinking_text = ''
full_guess_text = ''
token_usage = {}
- res = self.llm.stream(guess_msg)
+ res = process_stream(self.llm.stream(guess_msg), token_usage)
for chunk in res:
- SQLBotLogUtil.info(chunk)
- reasoning_content_chunk = ''
- if 'reasoning_content' in chunk.additional_kwargs:
- reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
- # else:
- # reasoning_content_chunk = chunk.get('reasoning_content')
- if reasoning_content_chunk is None:
- reasoning_content_chunk = ''
- full_thinking_text += reasoning_content_chunk
-
- full_guess_text += chunk.content
- yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
- get_token_usage(chunk, token_usage)
+ if chunk.get('content'):
+ full_guess_text += chunk.get('content')
+ if chunk.get('reasoning_content'):
+ full_thinking_text += chunk.get('reasoning_content')
+ yield chunk
guess_msg.append(AIMessage(full_guess_text))
@@ -405,9 +386,8 @@ def select_datasource(self):
if self.current_assistant and self.current_assistant.type != 4:
_ds_list = get_assistant_ds(session=self.session, llm_service=self)
else:
- oid: str = self.current_user.oid
stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where(
- CoreDatasource.oid == oid)
+ and_(CoreDatasource.oid == self.current_user.oid))
_ds_list = [
{
"id": ds.id,
@@ -425,57 +405,58 @@ def select_datasource(self):
full_thinking_text = ''
full_text = ''
-
if not ignore_auto_select:
- _ds_list_dict = []
- for _ds in _ds_list:
- _ds_list_dict.append(_ds)
- datasource_msg.append(
- HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode())))
-
- self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=self.session,
- ai_modal_id=self.chat_question.ai_modal_id,
- ai_modal_name=self.chat_question.ai_modal_name,
- operate=OperationEnum.CHOOSE_DATASOURCE,
- record_id=self.record.id,
- full_message=[{'type': msg.type,
- 'content': msg.content} for
- msg in datasource_msg])
-
- token_usage = {}
- res = self.llm.stream(datasource_msg)
- for chunk in res:
- SQLBotLogUtil.info(chunk)
- reasoning_content_chunk = ''
- if 'reasoning_content' in chunk.additional_kwargs:
- reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
- # else:
- # reasoning_content_chunk = chunk.get('reasoning_content')
- if reasoning_content_chunk is None:
- reasoning_content_chunk = ''
- full_thinking_text += reasoning_content_chunk
-
- full_text += chunk.content
- yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
- get_token_usage(chunk, token_usage)
- datasource_msg.append(AIMessage(full_text))
-
- self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session,
- log=self.current_logs[
- OperationEnum.CHOOSE_DATASOURCE],
- full_message=[
- {'type': msg.type, 'content': msg.content}
- for msg in datasource_msg],
- reasoning_content=full_thinking_text,
- token_usage=token_usage)
-
- json_str = extract_nested_json(full_text)
+ if settings.TABLE_EMBEDDING_ENABLED:
+ ds = get_ds_embedding(self.session, self.current_user, _ds_list, self.out_ds_instance,
+ self.chat_question.question, self.current_assistant)
+ yield {'content': '{"id":' + str(ds.get('id')) + '}'}
+ else:
+ _ds_list_dict = []
+ for _ds in _ds_list:
+ _ds_list_dict.append(_ds)
+ datasource_msg.append(
+ HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode())))
+
+ self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=self.session,
+ ai_modal_id=self.chat_question.ai_modal_id,
+ ai_modal_name=self.chat_question.ai_modal_name,
+ operate=OperationEnum.CHOOSE_DATASOURCE,
+ record_id=self.record.id,
+ full_message=[{'type': msg.type,
+ 'content': msg.content}
+ for
+ msg in datasource_msg])
+
+ token_usage = {}
+ res = process_stream(self.llm.stream(datasource_msg), token_usage)
+ for chunk in res:
+ if chunk.get('content'):
+ full_text += chunk.get('content')
+ if chunk.get('reasoning_content'):
+ full_thinking_text += chunk.get('reasoning_content')
+ yield chunk
+ datasource_msg.append(AIMessage(full_text))
+
+ self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session,
+ log=self.current_logs[
+ OperationEnum.CHOOSE_DATASOURCE],
+ full_message=[
+ {'type': msg.type,
+ 'content': msg.content}
+ for msg in datasource_msg],
+ reasoning_content=full_thinking_text,
+ token_usage=token_usage)
+
+ json_str = extract_nested_json(full_text)
+ if json_str is None:
+ raise SingleMessageError(f'Cannot parse datasource from answer: {full_text}')
+ ds = orjson.loads(json_str)
_error: Exception | None = None
_datasource: int | None = None
_engine_type: str | None = None
try:
- data: dict = _ds_list[0] if ignore_auto_select else orjson.loads(json_str)
+ data: dict = _ds_list[0] if ignore_auto_select else ds
if data.get('id') and data.get('id') != 0:
_datasource = data['id']
@@ -497,7 +478,8 @@ def select_datasource(self):
self.chat_question.engine = (_ds.type_name if _ds.type != 'excel' else 'PostgreSQL') + get_version(
self.ds)
self.chat_question.db_schema = get_table_schema(session=self.session,
- current_user=self.current_user, ds=self.ds)
+ current_user=self.current_user, ds=self.ds,
+ question=self.chat_question.question)
_engine_type = self.chat_question.engine
_chat.engine_type = _ds.type_name
# save chat
@@ -514,16 +496,21 @@ def select_datasource(self):
except Exception as e:
_error = e
- if not ignore_auto_select:
+ if not ignore_auto_select and not settings.TABLE_EMBEDDING_ENABLED:
self.record = save_select_datasource_answer(session=self.session, record_id=self.record.id,
answer=orjson.dumps({'content': full_text}).decode(),
datasource=_datasource,
engine_type=_engine_type)
+ if self.ds:
+ oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1
+ ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
- self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question,
- self.ds.oid if isinstance(self.ds,
- CoreDatasource) else 1)
- self.init_messages()
+ self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, oid,
+ ds_id)
+ self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, ds_id,
+ oid)
+
+ self.init_messages()
if _error:
raise _error
@@ -544,21 +531,13 @@ def generate_sql(self):
full_thinking_text = ''
full_sql_text = ''
token_usage = {}
- res = self.llm.stream(self.sql_message)
+ res = process_stream(self.llm.stream(self.sql_message), token_usage)
for chunk in res:
- SQLBotLogUtil.info(chunk)
- reasoning_content_chunk = ''
- if 'reasoning_content' in chunk.additional_kwargs:
- reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
- # else:
- # reasoning_content_chunk = chunk.get('reasoning_content')
- if reasoning_content_chunk is None:
- reasoning_content_chunk = ''
- full_thinking_text += reasoning_content_chunk
-
- full_sql_text += chunk.content
- yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
- get_token_usage(chunk, token_usage)
+ if chunk.get('content'):
+ full_sql_text += chunk.get('content')
+ if chunk.get('reasoning_content'):
+ full_thinking_text += chunk.get('reasoning_content')
+ yield chunk
self.sql_message.append(AIMessage(full_sql_text))
@@ -591,18 +570,14 @@ def generate_with_sub_sql(self, sql, sub_mappings: list):
full_thinking_text = ''
full_dynamic_text = ''
- res = self.llm.stream(dynamic_sql_msg)
token_usage = {}
+ res = process_stream(self.llm.stream(dynamic_sql_msg), token_usage)
for chunk in res:
- SQLBotLogUtil.info(chunk)
- reasoning_content_chunk = ''
- if 'reasoning_content' in chunk.additional_kwargs:
- reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
- if reasoning_content_chunk is None:
- reasoning_content_chunk = ''
- full_thinking_text += reasoning_content_chunk
- full_dynamic_text += chunk.content
- get_token_usage(chunk, token_usage)
+ if chunk.get('content'):
+ full_dynamic_text += chunk.get('content')
+ if chunk.get('reasoning_content'):
+ full_thinking_text += chunk.get('reasoning_content')
+ yield chunk
dynamic_sql_msg.append(AIMessage(full_dynamic_text))
@@ -654,22 +629,13 @@ def build_table_filter(self, sql: str, filters: list):
in permission_sql_msg])
full_thinking_text = ''
full_filter_text = ''
- res = self.llm.stream(permission_sql_msg)
token_usage = {}
+ res = process_stream(self.llm.stream(permission_sql_msg), token_usage)
for chunk in res:
- SQLBotLogUtil.info(chunk)
- reasoning_content_chunk = ''
- if 'reasoning_content' in chunk.additional_kwargs:
- reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
- # else:
- # reasoning_content_chunk = chunk.get('reasoning_content')
- if reasoning_content_chunk is None:
- reasoning_content_chunk = ''
- full_thinking_text += reasoning_content_chunk
-
- full_filter_text += chunk.content
- # yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
- get_token_usage(chunk, token_usage)
+ if chunk.get('content'):
+ full_filter_text += chunk.get('content')
+ if chunk.get('reasoning_content'):
+ full_thinking_text += chunk.get('reasoning_content')
permission_sql_msg.append(AIMessage(full_filter_text))
@@ -719,21 +685,13 @@ def generate_chart(self, chart_type: Optional[str] = ''):
full_thinking_text = ''
full_chart_text = ''
token_usage = {}
- res = self.llm.stream(self.chart_message)
+ res = process_stream(self.llm.stream(self.chart_message), token_usage)
for chunk in res:
- SQLBotLogUtil.info(chunk)
- reasoning_content_chunk = ''
- if 'reasoning_content' in chunk.additional_kwargs:
- reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
- # else:
- # reasoning_content_chunk = chunk.get('reasoning_content')
- if reasoning_content_chunk is None:
- reasoning_content_chunk = ''
- full_thinking_text += reasoning_content_chunk
-
- full_chart_text += chunk.content
- yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
- get_token_usage(chunk, token_usage)
+ if chunk.get('content'):
+ full_chart_text += chunk.get('content')
+ if chunk.get('reasoning_content'):
+ full_thinking_text += chunk.get('reasoning_content')
+ yield chunk
self.chart_message.append(AIMessage(full_chart_text))
@@ -881,7 +839,7 @@ def save_sql_data(self, data_obj: Dict[str, Any]):
def finish(self):
return finish_record(session=self.session, record_id=self.record.id)
- def execute_sql(self, sql: str, tables):
+ def execute_sql(self, sql: str):
"""Execute SQL query
Args:
@@ -893,7 +851,7 @@ def execute_sql(self, sql: str, tables):
"""
SQLBotLogUtil.info(f"Executing SQL on ds_id {self.ds.id}: {sql}")
try:
- return exec_sql(ds=self.ds, sql=sql, origin_column=False, table_name=tables)
+ return exec_sql(ds=self.ds, sql=sql, origin_column=False)
except Exception as e:
if isinstance(e, ParseSQLResultError):
raise e
@@ -922,24 +880,36 @@ def await_result(self):
break
yield chunk
- def run_task_async(self, in_chat: bool = True):
- self.future = executor.submit(self.run_task_cache, in_chat)
+ def run_task_async(self, in_chat: bool = True, stream: bool = True,
+ finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART):
+ if in_chat:
+ stream = True
+ self.future = executor.submit(self.run_task_cache, in_chat, stream, finish_step)
- def run_task_cache(self, in_chat: bool = True):
- for chunk in self.run_task(in_chat):
+ def run_task_cache(self, in_chat: bool = True, stream: bool = True,
+ finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART):
+ for chunk in self.run_task(in_chat, stream, finish_step):
self.chunk_list.append(chunk)
- def run_task(self, in_chat: bool = True):
+ def run_task(self, in_chat: bool = True, stream: bool = True,
+ finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART):
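+ # finish_step bounds the pipeline: GENERATE_SQL stops once the SQL is produced,
+ # QUERY_DATA stops after it is executed, GENERATE_CHART runs end to end. With
+ # stream=False the generator yields a single json_result dict instead of SSE chunks.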
+ json_result: Dict[str, Any] = {'success': True}
try:
if self.ds:
+ oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1
+ ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question,
- self.ds.oid if isinstance(self.ds,
- CoreDatasource) else 1)
+ oid, ds_id)
+ self.chat_question.data_training = get_training_template(self.session, self.chat_question.question,
+ ds_id, oid)
+
self.init_messages()
# return id
if in_chat:
yield 'data:' + orjson.dumps({'type': 'id', 'id': self.get_record().id}).decode() + '\n\n'
+ if not stream:
+ json_result['record_id'] = self.get_record().id
# return title
if self.change_title:
@@ -949,8 +919,10 @@ def run_task(self, in_chat: bool = True):
brief=self.chat_question.question.strip()[:20]))
if in_chat:
yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n'
+ if not stream:
+ json_result['title'] = brief
- # select datasource if datasource is none
+ # select datasource if datasource is none
if not self.ds:
ds_res = self.select_datasource()
@@ -968,7 +940,8 @@ def run_task(self, in_chat: bool = True):
self.chat_question.db_schema = self.out_ds_instance.get_db_schema(
self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session,
current_user=self.current_user,
- ds=self.ds)
+ ds=self.ds,
+ question=self.chat_question.question)
else:
self.validate_history_ds()
@@ -988,7 +961,6 @@ def run_task(self, in_chat: bool = True):
'type': 'sql-result'}).decode() + '\n\n'
if in_chat:
yield 'data:' + orjson.dumps({'type': 'info', 'msg': 'sql generated'}).decode() + '\n\n'
-
# filter sql
SQLBotLogUtil.info(full_sql_text)
@@ -1022,14 +994,18 @@ def run_task(self, in_chat: bool = True):
sql = self.check_save_sql(res=full_sql_text)
else:
sql = self.check_save_sql(res=full_sql_text)
- tables = []
- SQLBotLogUtil.info(sql)
+ SQLBotLogUtil.info('sql: ' + sql)
+
+ if not stream:
+ json_result['sql'] = sql
+
format_sql = sqlparse.format(sql, reindent=True)
if in_chat:
yield 'data:' + orjson.dumps({'content': format_sql, 'type': 'sql'}).decode() + '\n\n'
else:
- yield f'```sql\n{format_sql}\n```\n\n'
+ if stream:
+ yield f'```sql\n{format_sql}\n```\n\n'
# execute sql
real_execute_sql = sql
@@ -1040,11 +1016,46 @@ def run_task(self, in_chat: bool = True):
subsql)
real_execute_sql = assistant_dynamic_sql
- result = self.execute_sql(sql=real_execute_sql, tables=tables)
- print(result)
+ if finish_step.value <= ChatFinishStep.GENERATE_SQL.value:
+ if in_chat:
+ yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n'
+ if not stream:
+ yield json_result
+ return
+
+ result = self.execute_sql(sql=real_execute_sql)
self.save_sql_data(data_obj=result)
if in_chat:
yield 'data:' + orjson.dumps({'content': 'execute-success', 'type': 'sql-data'}).decode() + '\n\n'
+ if not stream:
+ json_result['data'] = result.get('data')
+
+ if finish_step.value <= ChatFinishStep.QUERY_DATA.value:
+ if stream:
+ if in_chat:
+ yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n'
+ else:
+ data = []
+ _fields_list = []
+ _fields_skip = False
+ for _data in result.get('data'):
+ _row = []
+ for field in result.get('fields'):
+ _row.append(_data.get(field))
+ if not _fields_skip:
+ _fields_list.append(field)
+ data.append(_row)
+ _fields_skip = True
+
+ if not data or not _fields_list:
+ yield 'The SQL execution result is empty.\n\n'
+ else:
+ df = pd.DataFrame(np.array(data), columns=_fields_list)
+ markdown_table = df.to_markdown(index=False)
+ yield markdown_table + '\n\n'
+ else:
+ yield json_result
+ return
# generate chart
chart_res = self.generate_chart(chart_type)
@@ -1062,39 +1073,47 @@ def run_task(self, in_chat: bool = True):
SQLBotLogUtil.info(full_chart_text)
chart = self.check_save_chart(res=full_chart_text)
SQLBotLogUtil.info(chart)
+
+ if not stream:
+ json_result['chart'] = chart
+
if in_chat:
yield 'data:' + orjson.dumps(
{'content': orjson.dumps(chart).decode(), 'type': 'chart'}).decode() + '\n\n'
else:
- data = []
- _fields = {}
- if chart.get('columns'):
- for _column in chart.get('columns'):
- if _column:
- _fields[_column.get('value')] = _column.get('name')
- if chart.get('axis'):
- if chart.get('axis').get('x'):
- _fields[chart.get('axis').get('x').get('value')] = chart.get('axis').get('x').get('name')
- if chart.get('axis').get('y'):
- _fields[chart.get('axis').get('y').get('value')] = chart.get('axis').get('y').get('name')
- if chart.get('axis').get('series'):
- _fields[chart.get('axis').get('series').get('value')] = chart.get('axis').get('series').get(
- 'name')
- _fields_list = []
- _fields_skip = False
- for _data in result.get('data'):
- _row = []
- for field in result.get('fields'):
- _row.append(_data.get(field))
- if not _fields_skip:
- _fields_list.append(field if not _fields.get(field) else _fields.get(field))
- data.append(_row)
- _fields_skip = True
- df = pd.DataFrame(np.array(data), columns=_fields_list)
- markdown_table = df.to_markdown(index=False)
- yield markdown_table + '\n\n'
-
- record = self.finish()
+ if stream:
+ data = []
+ _fields = {}
+ if chart.get('columns'):
+ for _column in chart.get('columns'):
+ if _column:
+ _fields[_column.get('value')] = _column.get('name')
+ if chart.get('axis'):
+ if chart.get('axis').get('x'):
+ _fields[chart.get('axis').get('x').get('value')] = chart.get('axis').get('x').get('name')
+ if chart.get('axis').get('y'):
+ _fields[chart.get('axis').get('y').get('value')] = chart.get('axis').get('y').get('name')
+ if chart.get('axis').get('series'):
+ _fields[chart.get('axis').get('series').get('value')] = chart.get('axis').get('series').get(
+ 'name')
+ _fields_list = []
+ _fields_skip = False
+ for _data in result.get('data'):
+ _row = []
+ for field in result.get('fields'):
+ _row.append(_data.get(field))
+ if not _fields_skip:
+ _fields_list.append(field if not _fields.get(field) else _fields.get(field))
+ data.append(_row)
+ _fields_skip = True
+
+ if not data or not _fields_list:
+ yield 'The SQL execution result is empty.\n\n'
+ else:
+ df = pd.DataFrame(np.array(data), columns=_fields_list)
+ markdown_table = df.to_markdown(index=False)
+ yield markdown_table + '\n\n'
+
if in_chat:
yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n'
else:
@@ -1103,8 +1122,16 @@ def run_task(self, in_chat: bool = True):
yield '### generated chart picture\n\n'
image_url = request_picture(self.record.chat_id, self.record.id, chart, result)
SQLBotLogUtil.info(image_url)
- yield f'![{chart["type"]}]({image_url})'
+ if stream:
+ yield f'![{chart["type"]}]({image_url})'
+ else:
+ json_result['image_url'] = image_url
+
+ if not stream:
+ yield json_result
+
except Exception as e:
+ traceback.print_exc()
error_msg: str
if isinstance(e, SingleMessageError):
error_msg = str(e)
@@ -1120,7 +1147,14 @@ def run_task(self, in_chat: bool = True):
if in_chat:
yield 'data:' + orjson.dumps({'content': error_msg, 'type': 'error'}).decode() + '\n\n'
else:
- yield f'> ❌ **ERROR**\n\n> \n\n> {error_msg}。'
+ if stream:
+ yield f'> ❌ **ERROR**\n\n> \n\n> {error_msg}。'
+ else:
+ json_result['success'] = False
+ json_result['message'] = error_msg
+ yield json_result
+ finally:
+ self.finish()
def run_recommend_questions_task_async(self):
self.future = executor.submit(self.run_recommend_questions_task_cache)
@@ -1280,9 +1314,11 @@ def request_picture(chat_id: int, record_id: int, chart: dict, data: dict):
return request_path
-def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = {}):
+def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = None):
try:
if chunk.usage_metadata:
+ if token_usage is None:
+ token_usage = {}
token_usage['input_tokens'] = chunk.usage_metadata.get('input_tokens')
token_usage['output_tokens'] = chunk.usage_metadata.get('output_tokens')
token_usage['total_tokens'] = chunk.usage_metadata.get('total_tokens')
@@ -1290,6 +1326,104 @@ def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = {}):
pass
+def process_stream(res: Iterator[BaseMessageChunk],
+ token_usage: Dict[str, Any] = None,
+ enable_tag_parsing: bool = settings.PARSE_REASONING_BLOCK_ENABLED,
+ start_tag: str = settings.DEFAULT_REASONING_CONTENT_START,
+ end_tag: str = settings.DEFAULT_REASONING_CONTENT_END
+ ):
+ if token_usage is None:
+ token_usage = {}
+ in_thinking_block = False # whether we are currently inside a reasoning block
+ current_thinking = '' # reasoning content collected so far
+ pending_start_tag = '' # buffer for a possibly truncated start tag
+
+ for chunk in res:
+ SQLBotLogUtil.info(chunk)
+ reasoning_content_chunk = ''
+ content = chunk.content
+ output_content = '' # the content actually emitted
+
+ # Check for reasoning_content in additional_kwargs
+ if 'reasoning_content' in chunk.additional_kwargs:
+ reasoning_content = chunk.additional_kwargs.get('reasoning_content', '')
+ if reasoning_content is None:
+ reasoning_content = ''
+
+ # Accumulate reasoning content from additional_kwargs into current_thinking
+ current_thinking += reasoning_content
+ reasoning_content_chunk = reasoning_content
+
+ # Skip tag parsing only when current_thinking is not an empty string
+ if not in_thinking_block and current_thinking.strip() != '':
+ output_content = content # emit content unchanged
+ yield {
+ 'content': output_content,
+ 'reasoning_content': reasoning_content_chunk
+ }
+ get_token_usage(chunk, token_usage)
+ continue # skip the tag-parsing logic below
+
+ # Only when there is no valid reasoning content, and tag parsing is enabled, run the tag-parsing logic
+ # If part of a start tag was buffered, prepend it to the current content first
+ if pending_start_tag:
+ content = pending_start_tag + content
+ pending_start_tag = ''
+
+ # Check whether a reasoning block starts here (handles a possibly truncated start tag)
+ if enable_tag_parsing and not in_thinking_block and start_tag:
+ if start_tag in content:
+ start_idx = content.index(start_tag)
+ # Treat it as a real reasoning-block start only when nothing precedes the start tag
+ if start_idx == 0 or content[:start_idx].strip() == '':
+ # The full tag is present with no other text before it
+ output_content += content[:start_idx] # emit anything before the start tag
+ content = content[start_idx + len(start_tag):] # drop the start tag
+ in_thinking_block = True
+ else:
+ # Text precedes the start tag, so this is not a reasoning-block start
+ output_content += content
+ content = ''
+ else:
+ # Check whether the chunk may end with a partial start tag
+ for i in range(1, len(start_tag)):
+ if content.endswith(start_tag[:i]):
+ # Buffer the partial tag only when everything before it is whitespace
+ if content[:-i].strip() == '':
+ pending_start_tag = start_tag[:i]
+ content = content[:-i] # strip the possible partial tag
+ output_content += content
+ content = ''
+ break
+
+ # Handle content inside a reasoning block
+ if enable_tag_parsing and in_thinking_block and end_tag:
+ if end_tag in content:
+ # End tag found
+ end_idx = content.index(end_tag)
+ current_thinking += content[:end_idx] # collect the reasoning content
+ reasoning_content_chunk += current_thinking # add it to this chunk's reasoning output
+ content = content[end_idx + len(end_tag):] # keep only what follows the end tag
+ current_thinking = '' # reset the collected reasoning content
+ in_thinking_block = False
+ output_content += content # emit whatever follows the end tag
+ else:
+ # Keep collecting reasoning content until the end tag appears
+ current_thinking += content
+ reasoning_content_chunk += content
+ content = ''
+
+ else:
+ # Not inside a reasoning block, or tag parsing disabled: emit as-is
+ output_content += content
+
+ yield {
+ 'content': output_content,
+ 'reasoning_content': reasoning_content_chunk
+ }
+ get_token_usage(chunk, token_usage)
+
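+# A minimal consumption sketch (assumes `llm` is a BaseChatModel and `msgs` is a
+# prepared message list; this mirrors how the generate_* methods above read chunks):
+#
+#     usage: dict = {}
+#     answer, thinking = '', ''
+#     for part in process_stream(llm.stream(msgs), usage):
+#         answer += part.get('content') or ''
+#         thinking += part.get('reasoning_content') or ''
+#     # usage now carries input_tokens/output_tokens/total_tokens whenever the
+#     # provider reports usage_metadata on a chunk.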
+
def get_lang_name(lang: str):
if lang and lang == 'en':
return '英文'
diff --git a/backend/apps/data_training/__init__.py b/backend/apps/data_training/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/backend/apps/data_training/api/__init__.py b/backend/apps/data_training/api/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/backend/apps/data_training/api/data_training.py b/backend/apps/data_training/api/data_training.py
new file mode 100644
index 00000000..25422c2b
--- /dev/null
+++ b/backend/apps/data_training/api/data_training.py
@@ -0,0 +1,39 @@
+from typing import Optional
+
+from fastapi import APIRouter, Query
+
+from apps.data_training.curd.data_training import page_data_training, create_training, update_training, delete_training
+from apps.data_training.models.data_training_model import DataTrainingInfo
+from common.core.deps import SessionDep, CurrentUser, Trans
+
+router = APIRouter(tags=["DataTraining"], prefix="/system/data-training")
+
+
+@router.get("/page/{current_page}/{page_size}")
+async def pager(session: SessionDep, current_user: CurrentUser, current_page: int, page_size: int,
+ question: Optional[str] = Query(None, description="Question to search for (optional)")):
+ current_page, page_size, total_count, total_pages, _list = page_data_training(session, current_page, page_size,
+ question,
+ current_user.oid)
+
+ return {
+ "current_page": current_page,
+ "page_size": page_size,
+ "total_count": total_count,
+ "total_pages": total_pages,
+ "data": _list
+ }
+
+
+@router.put("")
+async def create_or_update(session: SessionDep, current_user: CurrentUser, trans: Trans, info: DataTrainingInfo):
+ oid = current_user.oid
+ if info.id:
+ return update_training(session, info, oid, trans)
+ else:
+ return create_training(session, info, oid, trans)
+
+
+@router.delete("")
+async def delete(session: SessionDep, id_list: list[int]):
+ delete_training(session, id_list)
diff --git a/backend/apps/data_training/curd/__init__.py b/backend/apps/data_training/curd/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/backend/apps/data_training/curd/data_training.py b/backend/apps/data_training/curd/data_training.py
new file mode 100644
index 00000000..15260a9a
--- /dev/null
+++ b/backend/apps/data_training/curd/data_training.py
@@ -0,0 +1,317 @@
+import datetime
+import logging
+import traceback
+from typing import List, Optional
+from xml.dom.minidom import parseString
+
+import dicttoxml
+from sqlalchemy import and_, select, func, delete, update, or_
+from sqlalchemy import text
+from sqlalchemy.orm.session import Session
+
+from apps.ai_model.embedding import EmbeddingModelCache
+from apps.data_training.models.data_training_model import DataTrainingInfo, DataTraining
+from apps.datasource.models.datasource import CoreDatasource
+from apps.template.generate_chart.generator import get_base_data_training_template
+from common.core.config import settings
+from common.core.deps import SessionDep, Trans
+from common.utils.embedding_threads import run_save_data_training_embeddings
+
+
+def page_data_training(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None,
+ oid: Optional[int] = 1):
+ _list: List[DataTrainingInfo] = []
+
+ current_page = max(1, current_page)
+ page_size = max(10, page_size)
+
+ total_count = 0
+ total_pages = 0
+
+ if name and name.strip() != "":
+ keyword_pattern = f"%{name.strip()}%"
+ parent_ids_subquery = (
+ select(DataTraining.id)
+ .where(and_(DataTraining.question.ilike(keyword_pattern), DataTraining.oid == oid)) # ILIKE filter
+ )
+ else:
+ parent_ids_subquery = (
+ select(DataTraining.id).where(and_(DataTraining.oid == oid))
+ )
+
+ count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery())
+ total_count = session.execute(count_stmt).scalar()
+ total_pages = (total_count + page_size - 1) // page_size
+
+ if current_page > total_pages:
+ current_page = 1
+
+ paginated_parent_ids = (
+ parent_ids_subquery
+ .order_by(DataTraining.create_time.desc())
+ .offset((current_page - 1) * page_size)
+ .limit(page_size)
+ .subquery()
+ )
+
+ stmt = (
+ select(
+ DataTraining.id,
+ DataTraining.oid,
+ DataTraining.datasource,
+ CoreDatasource.name,
+ DataTraining.question,
+ DataTraining.create_time,
+ DataTraining.description,
+ )
+ .outerjoin(CoreDatasource, and_(DataTraining.datasource == CoreDatasource.id))
+ .where(and_(DataTraining.id.in_(paginated_parent_ids)))
+ .order_by(DataTraining.create_time.desc())
+ )
+
+ result = session.execute(stmt)
+
+ for row in result:
+ _list.append(DataTrainingInfo(
+ id=row.id,
+ oid=row.oid,
+ datasource=row.datasource,
+ datasource_name=row.name,
+ question=row.question,
+ create_time=row.create_time,
+ description=row.description,
+ ))
+
+ return current_page, page_size, total_count, total_pages, _list
+
+
+def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans):
+ create_time = datetime.datetime.now()
+ if info.datasource is None:
+ raise Exception(trans("i18n_data_training.datasource_cannot_be_none"))
+ parent = DataTraining(question=info.question, create_time=create_time, description=info.description, oid=oid,
+ datasource=info.datasource)
+
+ exists = session.query(
+ session.query(DataTraining).filter(
+ and_(DataTraining.question == info.question, DataTraining.oid == oid,
+ DataTraining.datasource == info.datasource)).exists()).scalar()
+ if exists:
+ raise Exception(trans("i18n_data_training.exists_in_db"))
+
+ result = DataTraining(**parent.model_dump())
+
+ session.add(parent)
+ session.flush()
+ session.refresh(parent)
+
+ result.id = parent.id
+ session.commit()
+
+ # embedding
+ run_save_data_training_embeddings([result.id])
+
+ return result.id
+
+
+def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans):
+ if info.datasource is None:
+ raise Exception(trans("i18n_data_training.datasource_cannot_be_none"))
+
+ count = session.query(DataTraining).filter(
+ DataTraining.id == info.id
+ ).count()
+ if count == 0:
+ raise Exception(trans('i18n_data_training.data_training_not_exists'))
+
+ exists = session.query(
+ session.query(DataTraining).filter(
+ and_(DataTraining.question == info.question, DataTraining.oid == oid,
+ DataTraining.datasource == info.datasource,
+ DataTraining.id != info.id)).exists()).scalar()
+ if exists:
+ raise Exception(trans("i18n_data_training.exists_in_db"))
+
+ stmt = update(DataTraining).where(and_(DataTraining.id == info.id)).values(
+ question=info.question,
+ description=info.description,
+ datasource=info.datasource,
+ )
+ session.execute(stmt)
+ session.commit()
+
+ # embedding
+ run_save_data_training_embeddings([info.id])
+
+ return info.id
+
+
+def delete_training(session: SessionDep, ids: list[int]):
+ stmt = delete(DataTraining).where(and_(DataTraining.id.in_(ids)))
+ session.execute(stmt)
+ session.commit()
+
+
+# def run_save_embeddings(ids: List[int]):
+# executor.submit(save_embeddings, ids)
+#
+#
+# def fill_empty_embeddings():
+# executor.submit(run_fill_empty_embeddings)
+
+
+def run_fill_empty_embeddings(session: Session):
+ if not settings.EMBEDDING_ENABLED:
+ return
+
+ stmt = select(DataTraining.id).where(and_(DataTraining.embedding.is_(None)))
+ results = session.execute(stmt).scalars().all()
+
+ save_embeddings(session, results)
+
+
+def save_embeddings(session: Session, ids: List[int]):
+ if not settings.EMBEDDING_ENABLED:
+ return
+
+ if not ids or len(ids) == 0:
+ return
+ try:
+
+ _list = session.query(DataTraining).filter(and_(DataTraining.id.in_(ids))).all()
+
+ _question_list = [item.question for item in _list]
+
+ model = EmbeddingModelCache.get_model()
+
+ results = model.embed_documents(_question_list)
+
+ for index in range(len(results)):
+ item = results[index]
+ stmt = update(DataTraining).where(and_(DataTraining.id == _list[index].id)).values(embedding=item)
+ session.execute(stmt)
+ session.commit()
+
+ except Exception:
+ traceback.print_exc()
+
+
+embedding_sql = f"""
+SELECT id, datasource, question, similarity
+FROM
+(SELECT id, datasource, question, oid,
+( 1 - (embedding <=> :embedding_array) ) AS similarity
+FROM data_training AS child
+) TEMP
+WHERE similarity > {settings.EMBEDDING_DATA_TRAINING_SIMILARITY} AND oid = :oid AND datasource = :datasource
+ORDER BY similarity DESC
+LIMIT {settings.EMBEDDING_DATA_TRAINING_TOP_COUNT}
+"""
+
+
+def select_training_by_question(session: SessionDep, question: str, oid: int, datasource: int):
+ if question.strip() == "":
+ return []
+
+ _list: List[DataTraining] = []
+
+ # maybe use label later?
+ stmt = (
+ select(
+ DataTraining.id,
+ DataTraining.question,
+ )
+ .where(
+ and_(or_(text(":sentence ILIKE '%' || question || '%'"), text("question ILIKE '%' || :sentence || '%'")),
+ DataTraining.oid == oid,
+ DataTraining.datasource == datasource)
+ )
+ )
+
+ results = session.execute(stmt, {'sentence': question}).fetchall()
+
+ for row in results:
+ _list.append(DataTraining(id=row.id, question=row.question))
+
+ if settings.EMBEDDING_ENABLED:
+ try:
+ model = EmbeddingModelCache.get_model()
+
+ embedding = model.embed_query(question)
+
+ results = session.execute(text(embedding_sql),
+ {'embedding_array': str(embedding), 'oid': oid, 'datasource': datasource})
+
+ for row in results:
+ _list.append(DataTraining(id=row.id, question=row.question))
+
+ except Exception:
+ traceback.print_exc()
+
+ _map: dict = {}
+ _ids: list[int] = []
+ for row in _list:
+ if row.id in _ids:
+ continue
+ else:
+ _ids.append(row.id)
+
+ if len(_ids) == 0:
+ return []
+
+ t_list = session.query(DataTraining.id, DataTraining.datasource, DataTraining.question,
+ DataTraining.description).filter(
+ and_(DataTraining.id.in_(_ids))).all()
+
+ for row in t_list:
+ _map[row.id] = {'question': row.question, 'suggestion-answer': row.description}
+
+ _results: list[dict] = []
+ for key in _map.keys():
+ _results.append(_map.get(key))
+
+ return _results
+
+
+def to_xml_string(_dict: list[dict] | dict, root: str = 'sql-examples') -> str:
+ item_name_func = lambda x: 'sql-example' if x == 'sql-examples' else 'item'
+ dicttoxml.LOG.setLevel(logging.ERROR)
+ xml = dicttoxml.dicttoxml(_dict,
+ cdata=['question', 'suggestion-answer'],
+ custom_root=root,
+ item_func=item_name_func,
+ xml_declaration=False,
+ encoding='utf-8',
+ attr_type=False).decode('utf-8')
+ pretty_xml = parseString(xml).toprettyxml()
+
+ if pretty_xml.startswith('<?xml'):
+ end_index = pretty_xml.index('>') + 1
+ pretty_xml = pretty_xml[end_index:].lstrip()
+
+ # Replace all XML escape entities with their literal characters
+ escape_map = {
+ '&lt;': '<',
+ '&gt;': '>',
+ '&amp;': '&',
+ '&quot;': '"',
+ '&apos;': "'"
+ }
+ for escaped, original in escape_map.items():
+ pretty_xml = pretty_xml.replace(escaped, original)
+
+ return pretty_xml
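+# Illustrative output shape for the default root (values elided):
+#
+#     <sql-examples>
+#       <sql-example>
+#         <question>...</question>
+#         <suggestion-answer>...</suggestion-answer>
+#       </sql-example>
+#     </sql-examples>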
+
+
+def get_training_template(session: SessionDep, question: str, datasource: int, oid: Optional[int] = 1) -> str:
+ if not oid:
+ oid = 1
+ if not datasource:
+ return ''
+ _results = select_training_by_question(session, question, oid, datasource)
+ if _results and len(_results) > 0:
+ data_training = to_xml_string(_results)
+ template = get_base_data_training_template().format(data_training=data_training)
+ return template
+ else:
+ return ''
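+# The rendered block fills the {data_training} placeholder that
+# AiModelQuestion.sql_sys_question() passes into the SQL prompt template.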
diff --git a/backend/apps/data_training/models/__init__.py b/backend/apps/data_training/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/backend/apps/data_training/models/data_training_model.py b/backend/apps/data_training/models/data_training_model.py
new file mode 100644
index 00000000..b2806400
--- /dev/null
+++ b/backend/apps/data_training/models/data_training_model.py
@@ -0,0 +1,28 @@
+from datetime import datetime
+from typing import List, Optional
+
+from pgvector.sqlalchemy import VECTOR
+from pydantic import BaseModel
+from sqlalchemy import Column, Text, BigInteger, DateTime, Identity
+from sqlmodel import SQLModel, Field
+
+
+class DataTraining(SQLModel, table=True):
+ __tablename__ = "data_training"
+ id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True))
+ oid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True, default=1))
+ datasource: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True))
+ create_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=False), nullable=True))
+ question: Optional[str] = Field(max_length=255)
+ description: Optional[str] = Field(sa_column=Column(Text, nullable=True))
+ embedding: Optional[List[float]] = Field(sa_column=Column(VECTOR(), nullable=True))
+
+
+class DataTrainingInfo(BaseModel):
+ id: Optional[int] = None
+ oid: Optional[int] = None
+ datasource: Optional[int] = None
+ datasource_name: Optional[str] = None
+ create_time: Optional[datetime] = None
+ question: Optional[str] = None
+ description: Optional[str] = None
diff --git a/backend/apps/datasource/api/datasource.py b/backend/apps/datasource/api/datasource.py
index 69cad28f..9c182e34 100644
--- a/backend/apps/datasource/api/datasource.py
+++ b/backend/apps/datasource/api/datasource.py
@@ -193,6 +193,7 @@ def inner():
return await asyncio.to_thread(inner)
+# not used
@router.post("/fieldEnum/{id}")
async def field_enum(session: SessionDep, id: int):
def inner():
diff --git a/backend/apps/datasource/api/table_relation.py b/backend/apps/datasource/api/table_relation.py
new file mode 100644
index 00000000..6a65060a
--- /dev/null
+++ b/backend/apps/datasource/api/table_relation.py
@@ -0,0 +1,29 @@
+# Author: Junjun
+# Date: 2025/9/24
+from typing import List
+
+from fastapi import APIRouter, HTTPException
+
+from apps.datasource.models.datasource import CoreDatasource
+from common.core.deps import SessionDep
+
+router = APIRouter(tags=["table_relation"], prefix="/table_relation")
+
+
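+# table_relation stores the front-end relation-graph JSON; edge entries carry
+# source/target objects whose 'cell' is a table id and 'port' a field id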
+@router.post("/save/{ds_id}")
+async def save_relation(session: SessionDep, ds_id: int, relation: List[dict]):
+ ds = session.get(CoreDatasource, ds_id)
+ if ds:
+ ds.table_relation = relation
+ session.commit()
+ else:
+ raise HTTPException(status_code=404, detail="no datasource")
+ return True
+
+
+@router.post("/get/{ds_id}")
+async def get_relation(session: SessionDep, ds_id: int):
+ ds = session.get(CoreDatasource, ds_id)
+ if ds:
+ return ds.table_relation if ds.table_relation else []
+ return []
diff --git a/backend/apps/datasource/crud/datasource.py b/backend/apps/datasource/crud/datasource.py
index b4720bad..4380b3cf 100644
--- a/backend/apps/datasource/crud/datasource.py
+++ b/backend/apps/datasource/crud/datasource.py
@@ -4,14 +4,16 @@
from fastapi import HTTPException
from sqlalchemy import and_, text
+from sqlbot_xpack.permissions.models.ds_rules import DsRules
from sqlmodel import select
from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user
+from apps.datasource.embedding.table_embedding import get_table_embedding
from apps.datasource.utils.utils import aes_decrypt
from apps.db.constant import DB
from apps.db.db import get_tables, get_fields, exec_sql, check_connection
from apps.db.engine import get_engine_config, get_engine_conn
-from apps.db.type import db_type_relation
+from common.core.config import settings
from common.core.deps import SessionDep, CurrentUser, Trans
from common.utils.utils import deepcopy_ignore_extra
from .table import get_tables_by_ds_id
@@ -70,7 +72,7 @@ def create_ds(session: SessionDep, trans: Trans, user: CurrentUser, create_ds: C
ds.create_by = user.id
ds.oid = user.oid if user.oid is not None else 1
ds.status = "Success"
- ds.type_name = db_type_relation()[ds.type]
+ ds.type_name = DB.get_db(ds.type).db_name
record = CoreDatasource(**ds.model_dump())
session.add(record)
session.flush()
@@ -250,8 +252,9 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table
f_list = [f for f in data.fields if f.checked]
if is_normal_user(current_user):
# column is checked, and, column permission for data.fields
+ contain_rules = session.query(DsRules).all()
f_list = get_column_permission_fields(session=session, current_user=current_user, table=data.table,
- fields=f_list)
+ fields=f_list, contain_rules=contain_rules)
# row permission tree
where_str = ''
@@ -297,7 +300,7 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table
sql = f"""SELECT "{'", "'.join(fields)}" FROM "{data.table.table_name}"
{where}
LIMIT 100"""
- return exec_sql(ds, sql, True, [data.table.table_name])
+ return exec_sql(ds, sql, True)
def fieldEnum(session: SessionDep, id: int):
@@ -313,7 +316,7 @@ def fieldEnum(session: SessionDep, id: int):
db = DB.get_db(ds.type)
sql = f"""SELECT DISTINCT {db.prefix}{field.field_name}{db.suffix} FROM {db.prefix}{table.table_name}{db.suffix}"""
- res = exec_sql(ds, sql, True, [table.table_name])
+ res = exec_sql(ds, sql, True)
return [item.get(res.get('fields')[0]) for item in res.get('data')]
@@ -336,42 +339,122 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core
tables = session.query(CoreTable).filter(CoreTable.ds_id == ds.id).all()
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
schema = conf.dbSchema if conf.dbSchema is not None and conf.dbSchema != "" else conf.database
+
+ # fetch all checked fields for these tables in a single query
+ table_ids = [table.id for table in tables]
+ all_fields = session.query(CoreField).filter(
+ and_(CoreField.table_id.in_(table_ids), CoreField.checked == True)).all()
+ # group fields by table_id
+ fields_dict = {}
+ for field in all_fields:
+ if fields_dict.get(field.table_id):
+ fields_dict.get(field.table_id).append(field)
+ else:
+ fields_dict[field.table_id] = [field]
+
+ contain_rules = session.query(DsRules).all()
for table in tables:
- fields = session.query(CoreField).filter(and_(CoreField.table_id == table.id, CoreField.checked == True)).all()
+ # fields = session.query(CoreField).filter(and_(CoreField.table_id == table.id, CoreField.checked == True)).all()
+ fields = fields_dict.get(table.id)
# do column permissions, filter fields
- fields = get_column_permission_fields(session=session, current_user=current_user, table=table, fields=fields)
+ fields = get_column_permission_fields(session=session, current_user=current_user, table=table, fields=fields,
+ contain_rules=contain_rules)
_list.append(TableAndFields(schema=schema, table=table, fields=fields))
return _list
-def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource) -> str:
+def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str,
+ embedding: bool = True) -> str:
schema_str = ""
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
if len(table_objs) == 0:
return schema_str
db_name = table_objs[0].schema
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
+ tables = []
+ all_tables = [] # keep every table so relation handling can restore ones dropped by the embedding filter
for obj in table_objs:
- schema_str += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}"
+ schema_table = ''
+ schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}"
table_comment = ''
if obj.table.custom_comment:
table_comment = obj.table.custom_comment.strip()
if table_comment == '':
- schema_str += '\n[\n'
+ schema_table += '\n[\n'
else:
- schema_str += f", {table_comment}\n[\n"
-
- field_list = []
- for field in obj.fields:
- field_comment = ''
- if field.custom_comment:
- field_comment = field.custom_comment.strip()
- if field_comment == '':
- field_list.append(f"({field.field_name}:{field.field_type})")
- else:
- field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})")
- schema_str += ",\n".join(field_list)
- schema_str += '\n]\n'
- # todo 外键
+ schema_table += f", {table_comment}\n[\n"
+
+ if obj.fields:
+ field_list = []
+ for field in obj.fields:
+ field_comment = ''
+ if field.custom_comment:
+ field_comment = field.custom_comment.strip()
+ if field_comment == '':
+ field_list.append(f"({field.field_name}:{field.field_type})")
+ else:
+ field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})")
+ schema_table += ",\n".join(field_list)
+ schema_table += '\n]\n'
+
+ t_obj = {"id": obj.table.id, "schema_table": schema_table}
+ tables.append(t_obj)
+ all_tables.append(t_obj)
+
+ # rank tables by embedding similarity to the question and keep only the best matches
+ if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
+ tables = get_table_embedding(session, current_user, tables, question)
+ # concatenate the selected table schemas
+ if tables:
+ for s in tables:
+ schema_str += s.get('schema_table')
+
+ # field relation
+ if tables and ds.table_relation:
+ relations = list(filter(lambda x: x.get('shape') == 'edge', ds.table_relation))
+ if relations:
+ # restore tables that relations reference, and drop relations touching no selected table
+ # (cell/port ids in the graph JSON may be strings, so cast to int before comparing)
+ embedding_table_ids = [s.get('id') for s in tables]
+ all_relations = list(
+ filter(lambda x: int(x.get('source').get('cell')) in embedding_table_ids or int(x.get('target').get(
+ 'cell')) in embedding_table_ids, relations))
+
+ # collect the table ids referenced by the remaining relations
+ relation_table_ids = []
+ for r in all_relations:
+ relation_table_ids.append(int(r.get('source').get('cell')))
+ relation_table_ids.append(int(r.get('target').get('cell')))
+ relation_table_ids = list(set(relation_table_ids))
+ # map table id -> table name
+ table_records = session.query(CoreTable).filter(CoreTable.id.in_(relation_table_ids)).all()
+ table_dict = {}
+ for ele in table_records:
+ table_dict[ele.id] = ele.table_name
+
+ # tables referenced by relations but filtered out by the embedding step
+ lost_table_ids = list(set(relation_table_ids) - set(embedding_table_ids))
+ # append their schemas back so the foreign keys below stay resolvable
+ lost_tables = list(filter(lambda x: x.get('id') in lost_table_ids, all_tables))
+ if lost_tables:
+ for s in lost_tables:
+ schema_str += s.get('schema_table')
+
+ # collect field ids from the relations and map field id -> field name
+ relation_field_ids = []
+ for relation in all_relations:
+ relation_field_ids.append(int(relation.get('source').get('port')))
+ relation_field_ids.append(int(relation.get('target').get('port')))
+ relation_field_ids = list(set(relation_field_ids))
+ field_records = session.query(CoreField).filter(CoreField.id.in_(relation_field_ids)).all()
+ field_dict = {}
+ for ele in field_records:
+ field_dict[ele.id] = ele.field_name
+
+ if all_relations:
+ schema_str += '【Foreign keys】\n'
+ for ele in all_relations:
+ schema_str += f"{table_dict.get(int(ele.get('source').get('cell')))}.{field_dict.get(int(ele.get('source').get('port')))}={table_dict.get(int(ele.get('target').get('cell')))}.{field_dict.get(int(ele.get('target').get('port')))}\n"
+
return schema_str
diff --git a/backend/apps/datasource/crud/permission.py b/backend/apps/datasource/crud/permission.py
index 80c32efd..87071812 100644
--- a/backend/apps/datasource/crud/permission.py
+++ b/backend/apps/datasource/crud/permission.py
@@ -2,13 +2,15 @@
from typing import List, Optional
from sqlalchemy import and_
-from apps.datasource.crud.row_permission import transFilterTree
-from apps.datasource.models.datasource import CoreDatasource, CoreField, CoreTable
-from common.core.deps import CurrentUser, SessionDep
from sqlbot_xpack.permissions.api.permission import transRecord2DTO
from sqlbot_xpack.permissions.models.ds_permission import DsPermission, PermissionDTO
from sqlbot_xpack.permissions.models.ds_rules import DsRules
+from apps.datasource.crud.row_permission import transFilterTree
+from apps.datasource.models.datasource import CoreDatasource, CoreField, CoreTable
+from common.core.deps import CurrentUser, SessionDep
+
+
def get_row_permission_filters(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource,
tables: Optional[list] = None, single_table: Optional[CoreTable] = None):
if single_table:
@@ -20,10 +22,10 @@ def get_row_permission_filters(session: SessionDep, current_user: CurrentUser, d
filters = []
if is_normal_user(current_user):
+ contain_rules = session.query(DsRules).all()
for table in table_list:
row_permissions = session.query(DsPermission).filter(
and_(DsPermission.table_id == table.id, DsPermission.type == 'row')).all()
- contain_rules = session.query(DsRules).all()
res: List[PermissionDTO] = []
if row_permissions is not None:
for permission in row_permissions:
@@ -35,6 +37,7 @@ def get_row_permission_filters(session: SessionDep, current_user: CurrentUser, d
if p_list is not None and u_list is not None and permission.id in p_list and (
current_user.id in u_list or f'{current_user.id}' in u_list):
flag = True
+ break
if flag:
res.append(transRecord2DTO(session, permission))
where_str = transFilterTree(session, res, ds)
@@ -43,11 +46,10 @@ def get_row_permission_filters(session: SessionDep, current_user: CurrentUser, d
def get_column_permission_fields(session: SessionDep, current_user: CurrentUser, table: CoreTable,
- fields: list[CoreField]):
+ fields: list[CoreField], contain_rules: list[DsRules]):
if is_normal_user(current_user):
column_permissions = session.query(DsPermission).filter(
and_(DsPermission.table_id == table.id, DsPermission.type == 'column')).all()
- contain_rules = session.query(DsRules).all()
if column_permissions is not None:
for permission in column_permissions:
# check permission and user in same rules
@@ -63,6 +65,7 @@ def get_column_permission_fields(session: SessionDep, current_user: CurrentUser,
if p_list is not None and u_list is not None and permission.id in p_list and (
current_user.id in u_list or f'{current_user.id}' in u_list):
flag = True
+ break
if flag:
permission_list = json.loads(permission.permissions)
fields = filter_list(fields, permission_list)
diff --git a/backend/apps/datasource/embedding/__init__.py b/backend/apps/datasource/embedding/__init__.py
new file mode 100644
index 00000000..87bb6f5d
--- /dev/null
+++ b/backend/apps/datasource/embedding/__init__.py
@@ -0,0 +1,2 @@
+# Author: Junjun
+# Date: 2025/9/18
diff --git a/backend/apps/datasource/embedding/ds_embedding.py b/backend/apps/datasource/embedding/ds_embedding.py
new file mode 100644
index 00000000..7aad7d29
--- /dev/null
+++ b/backend/apps/datasource/embedding/ds_embedding.py
@@ -0,0 +1,59 @@
+# Author: Junjun
+# Date: 2025/9/18
+import json
+import traceback
+from typing import Optional
+
+from apps.ai_model.embedding import EmbeddingModelCache
+from apps.datasource.crud.datasource import get_table_schema
+from apps.datasource.embedding.utils import cosine_similarity
+from apps.datasource.models.datasource import CoreDatasource
+from apps.system.crud.assistant import AssistantOutDs
+from common.core.deps import CurrentAssistant
+from common.core.deps import SessionDep, CurrentUser
+from common.utils.utils import SQLBotLogUtil
+
+
+def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, out_ds: AssistantOutDs,
+ question: str,
+ current_assistant: Optional[CurrentAssistant] = None):
+ _list = []
+ if current_assistant and current_assistant.type != 4:
+ if out_ds.ds_list:
+ for _ds in out_ds.ds_list:
+ ds = out_ds.get_ds(_ds.id)
+ table_schema = out_ds.get_db_schema(_ds.id)
+ ds_info = f"{ds.name}, {ds.description}\n"
+ ds_schema = ds_info + table_schema
+ _list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds})
+ else:
+ for _ds in _ds_list:
+ if _ds.get('id'):
+ ds = session.get(CoreDatasource, _ds.get('id'))
+ table_schema = get_table_schema(session, current_user, ds, question, embedding=False)
+ ds_info = f"{ds.name}, {ds.description}\n"
+ ds_schema = ds_info + table_schema
+ _list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds})
+
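+ # embed each candidate datasource schema and the question, then pick the most similar datasource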
+ if _list:
+ try:
+ text = [s.get('ds_schema') for s in _list]
+
+ model = EmbeddingModelCache.get_model()
+ results = model.embed_documents(text)
+
+ q_embedding = model.embed_query(question)
+ for index in range(len(results)):
+ item = results[index]
+ _list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)
+
+ _list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
+ # print(len(_list))
+ SQLBotLogUtil.info(json.dumps(
+ [{"id": ele.get("id"), "name": ele.get("ds").name, "cosine_similarity": ele.get("cosine_similarity")}
+ for ele in _list]))
+ ds = _list[0].get('ds')
+ return {"id": ds.id, "name": ds.name, "description": ds.description}
+ except Exception:
+ traceback.print_exc()
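+ # fall through on failure: return the unranked candidate list instead of a single datasource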
+ return _list
diff --git a/backend/apps/datasource/embedding/table_embedding.py b/backend/apps/datasource/embedding/table_embedding.py
new file mode 100644
index 00000000..1a3fe896
--- /dev/null
+++ b/backend/apps/datasource/embedding/table_embedding.py
@@ -0,0 +1,41 @@
+# Author: Junjun
+# Date: 2025/9/23
+import json
+import time
+import traceback
+
+from apps.ai_model.embedding import EmbeddingModelCache
+from apps.datasource.embedding.utils import cosine_similarity
+from common.core.config import settings
+from common.core.deps import SessionDep, CurrentUser
+from common.utils.utils import SQLBotLogUtil
+
+
+def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[dict], question: str):
+ _list = []
+ for table in tables:
+ _list.append({"id": table.get('id'), "schema_table": table.get('schema_table'), "cosine_similarity": 0.0})
+
+ if _list:
+ try:
+ text = [s.get('schema_table') for s in _list]
+
+ model = EmbeddingModelCache.get_model()
+ start_time = time.time()
+ results = model.embed_documents(text)
+ end_time = time.time()
+ SQLBotLogUtil.info(f"table embedding took {end_time - start_time:.3f}s")
+
+ q_embedding = model.embed_query(question)
+ for index in range(len(results)):
+ item = results[index]
+ _list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)
+
+ _list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
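+ # keep only the N tables most similar to the question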
+ _list = _list[:settings.TABLE_EMBEDDING_COUNT]
+ # print(len(_list))
+ SQLBotLogUtil.info(json.dumps(_list))
+ return _list
+ except Exception:
+ traceback.print_exc()
+ return _list
diff --git a/backend/apps/datasource/embedding/utils.py b/backend/apps/datasource/embedding/utils.py
new file mode 100644
index 00000000..3f6ddced
--- /dev/null
+++ b/backend/apps/datasource/embedding/utils.py
@@ -0,0 +1,18 @@
+# Author: Junjun
+# Date: 2025/9/23
+import math
+
+
+def cosine_similarity(vec_a, vec_b):
+ if len(vec_a) != len(vec_b):
+ raise ValueError("vectors must have the same dimension")
+
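+ # cosine similarity: cos(theta) = (a . b) / (|a| * |b|)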
+ dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
+
+ norm_a = math.sqrt(sum(a * a for a in vec_a))
+ norm_b = math.sqrt(sum(b * b for b in vec_b))
+
+ if norm_a == 0 or norm_b == 0:
+ return 0.0
+
+ return dot_product / (norm_a * norm_b)
diff --git a/backend/apps/datasource/models/datasource.py b/backend/apps/datasource/models/datasource.py
index 6496a857..78f23916 100644
--- a/backend/apps/datasource/models/datasource.py
+++ b/backend/apps/datasource/models/datasource.py
@@ -2,13 +2,14 @@
from typing import List, Optional
from pydantic import BaseModel
-from sqlalchemy import Column, Text, BigInteger, DateTime, Integer, Identity
+from sqlalchemy import Column, Text, BigInteger, DateTime, Identity
+from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import SQLModel, Field
class CoreDatasource(SQLModel, table=True):
__tablename__ = "core_datasource"
- id: int = Field(sa_column=Column(Integer, Identity(always=True), nullable=False, primary_key=True))
+ id: int = Field(sa_column=Column(BigInteger, Identity(always=True), nullable=False, primary_key=True))
name: str = Field(max_length=128, nullable=False)
description: str = Field(max_length=512, nullable=True)
type: str = Field(max_length=64)
@@ -19,11 +20,12 @@ class CoreDatasource(SQLModel, table=True):
status: str = Field(max_length=64, nullable=True)
num: str = Field(max_length=256, nullable=True)
oid: int = Field(sa_column=Column(BigInteger()))
+ table_relation: Optional[List[dict]] = Field(sa_column=Column(JSONB, nullable=True))
class CoreTable(SQLModel, table=True):
__tablename__ = "core_table"
- id: int = Field(sa_column=Column(Integer, Identity(always=True), nullable=False, primary_key=True))
+ id: int = Field(sa_column=Column(BigInteger, Identity(always=True), nullable=False, primary_key=True))
ds_id: int = Field(sa_column=Column(BigInteger()))
checked: bool = Field(default=True)
table_name: str = Field(sa_column=Column(Text))
@@ -33,7 +35,7 @@ class CoreTable(SQLModel, table=True):
class CoreField(SQLModel, table=True):
__tablename__ = "core_field"
- id: int = Field(sa_column=Column(Integer, Identity(always=True), nullable=False, primary_key=True))
+ id: int = Field(sa_column=Column(BigInteger, Identity(always=True), nullable=False, primary_key=True))
ds_id: int = Field(sa_column=Column(BigInteger()))
table_id: int = Field(sa_column=Column(BigInteger()))
checked: bool = Field(default=True)
diff --git a/backend/apps/db/constant.py b/backend/apps/db/constant.py
index 67074ac1..4194d8b7 100644
--- a/backend/apps/db/constant.py
+++ b/backend/apps/db/constant.py
@@ -13,19 +13,20 @@ def __init__(self, type_name):
class DB(Enum):
- mysql = ('mysql', '`', '`', ConnectType.sqlalchemy)
- sqlServer = ('sqlServer', '[', ']', ConnectType.sqlalchemy)
- pg = ('pg', '"', '"', ConnectType.sqlalchemy)
- excel = ('excel', '"', '"', ConnectType.sqlalchemy)
- oracle = ('oracle', '"', '"', ConnectType.sqlalchemy)
- ck = ('ck', '"', '"', ConnectType.sqlalchemy)
- dm = ('dm', '"', '"', ConnectType.py_driver)
- doris = ('doris', '`', '`', ConnectType.py_driver)
- redshift = ('redshift', '"', '"', ConnectType.py_driver)
- es = ('es', '"', '"', ConnectType.py_driver)
-
- def __init__(self, type, prefix, suffix, connect_type: ConnectType):
+ mysql = ('mysql', 'MySQL', '`', '`', ConnectType.sqlalchemy)
+ sqlServer = ('sqlServer', 'Microsoft SQL Server', '[', ']', ConnectType.sqlalchemy)
+ pg = ('pg', 'PostgreSQL', '"', '"', ConnectType.sqlalchemy)
+ excel = ('excel', 'Excel/CSV', '"', '"', ConnectType.sqlalchemy)
+ oracle = ('oracle', 'Oracle', '"', '"', ConnectType.sqlalchemy)
+ ck = ('ck', 'ClickHouse', '"', '"', ConnectType.sqlalchemy)
+ dm = ('dm', '达梦', '"', '"', ConnectType.py_driver)
+ doris = ('doris', 'Apache Doris', '`', '`', ConnectType.py_driver)
+ redshift = ('redshift', 'AWS Redshift', '"', '"', ConnectType.py_driver)
+ es = ('es', 'Elasticsearch', '"', '"', ConnectType.py_driver)
+
+ def __init__(self, type, db_name, prefix, suffix, connect_type: ConnectType):
self.type = type
+ self.db_name = db_name
self.prefix = prefix
self.suffix = suffix
self.connect_type = connect_type
diff --git a/backend/apps/db/db.py b/backend/apps/db/db.py
index a1da1492..1a22dc35 100644
--- a/backend/apps/db/db.py
+++ b/backend/apps/db/db.py
@@ -5,6 +5,8 @@
from decimal import Decimal
from typing import Optional
+import pymssql
+
from apps.db.db_sql import get_table_sql, get_field_sql, get_version_sql
from common.error import ParseSQLResultError
@@ -24,7 +26,7 @@
from common.core.deps import Trans
from common.utils.utils import SQLBotLogUtil
from fastapi import HTTPException
-from apps.db.es_engine import get_es_connect, get_es_index, get_es_fields, get_es_data
+from apps.db.es_engine import get_es_connect, get_es_index, get_es_fields, get_es_data_by_http
def get_uri(ds: CoreDatasource) -> str:
@@ -70,6 +72,35 @@ def get_uri_from_config(type: str, conf: DatasourceConf) -> str:
return db_url
+def get_extra_config(conf: DatasourceConf):
+ config_dict = {}
+ if conf.extraJdbc:
+ config_arr = conf.extraJdbc.split("&")
+ for config in config_arr:
+ kv = config.split("=")
+ if len(kv) == 2 and kv[0] and kv[1]:
+ config_dict[kv[0]] = kv[1]
+ else:
+ raise Exception(f'invalid extraJdbc param: {config}')
+ return config_dict
+
+
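+# build a raw DB-API connection (currently only sqlServer/pymssql is handled here)
+# so tds_version and extraJdbc params can be applied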
+def get_origin_connect(type: str, conf: DatasourceConf):
+ extra_config_dict = get_extra_config(conf)
+ if type == "sqlServer":
+ return pymssql.connect(
+ server=conf.host,
+ port=str(conf.port),
+ user=conf.username,
+ password=conf.password,
+ database=conf.database,
+ timeout=conf.timeout,
+ tds_version='7.0', # options: '4.2', '7.0', '8.0' ...,
+ **extra_config_dict
+ )
+
+
+# use sqlalchemy
def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine:
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
if conf.timeout is None:
@@ -87,7 +118,8 @@ def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine:
connect_args={"connect_timeout": conf.timeout},
pool_timeout=conf.timeout)
elif ds.type == 'sqlServer':
- engine = create_engine(get_uri(ds), pool_timeout=conf.timeout)
+ engine = create_engine('mssql+pymssql://', creator=lambda: get_origin_connect(ds.type, conf),
+ pool_timeout=conf.timeout)
elif ds.type == 'oracle':
engine = create_engine(get_uri(ds),
pool_timeout=conf.timeout)
@@ -119,9 +151,10 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs
return False
else:
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration)))
+ extra_config_dict = get_extra_config(conf)
if ds.type == 'dm':
with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
- port=conf.port) as conn, conn.cursor() as cursor:
+ port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor:
try:
cursor.execute('select 1', timeout=10).fetchall()
SQLBotLogUtil.info("success")
@@ -134,7 +167,7 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs
elif ds.type == 'doris':
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
port=conf.port, db=conf.database, connect_timeout=10,
- read_timeout=10) as conn, conn.cursor() as cursor:
+ read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor:
try:
cursor.execute('select 1')
SQLBotLogUtil.info("success")
@@ -148,7 +181,7 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs
with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database,
user=conf.username,
password=conf.password,
- timeout=10) as conn, conn.cursor() as cursor:
+ timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor:
try:
cursor.execute('select 1')
SQLBotLogUtil.info("success")
@@ -205,16 +238,17 @@ def get_version(ds: CoreDatasource | AssistantOutDsSchema):
res = result.fetchall()
version = res[0][0]
else:
+ extra_config_dict = get_extra_config(conf)
if ds.type == 'dm':
- with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
- port=conf.port) as conn, conn.cursor() as cursor:
- cursor.execute(sql, timeout=10)
+ # extra params belong on the connection (as in the other dm branches), not on execute()
+ with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
+ port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor:
+ cursor.execute(sql, timeout=10)
res = cursor.fetchall()
version = res[0][0]
elif ds.type == 'doris':
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
port=conf.port, db=conf.database, connect_timeout=10,
- read_timeout=10) as conn, conn.cursor() as cursor:
+ read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor:
cursor.execute(sql)
res = cursor.fetchall()
version = res[0][0]
@@ -233,29 +267,30 @@ def get_schema(ds: CoreDatasource):
with get_session(ds) as session:
sql: str = ''
if ds.type == "sqlServer":
- sql = f"""select name from sys.schemas"""
+ sql = """select name from sys.schemas"""
elif ds.type == "pg" or ds.type == "excel":
sql = """SELECT nspname
FROM pg_namespace"""
elif ds.type == "oracle":
- sql = f"""select * from all_users"""
+ sql = """select * from all_users"""
with session.execute(text(sql)) as result:
res = result.fetchall()
res_list = [item[0] for item in res]
return res_list
else:
+ extra_config_dict = get_extra_config(conf)
if ds.type == 'dm':
with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
- port=conf.port) as conn, conn.cursor() as cursor:
- cursor.execute(f"""select OBJECT_NAME from dba_objects where object_type='SCH'""", timeout=conf.timeout)
+ port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor:
+ cursor.execute("""select OBJECT_NAME from dba_objects where object_type='SCH'""", timeout=conf.timeout)
res = cursor.fetchall()
res_list = [item[0] for item in res]
return res_list
elif ds.type == 'redshift':
with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username,
password=conf.password,
- timeout=conf.timeout) as conn, conn.cursor() as cursor:
- cursor.execute(f"""SELECT nspname FROM pg_namespace""")
+ timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor:
+ cursor.execute("""SELECT nspname FROM pg_namespace""")
res = cursor.fetchall()
res_list = [item[0] for item in res]
return res_list
@@ -264,34 +299,35 @@ def get_schema(ds: CoreDatasource):
def get_tables(ds: CoreDatasource):
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
db = DB.get_db(ds.type)
- sql = get_table_sql(ds, conf, get_version(ds))
+ sql, sql_param = get_table_sql(ds, conf, get_version(ds))
if db.connect_type == ConnectType.sqlalchemy:
with get_session(ds) as session:
- with session.execute(text(sql)) as result:
+ with session.execute(text(sql), {"param": sql_param}) as result:
res = result.fetchall()
res_list = [TableSchema(*item) for item in res]
return res_list
else:
+ extra_config_dict = get_extra_config(conf)
if ds.type == 'dm':
with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
- port=conf.port) as conn, conn.cursor() as cursor:
- cursor.execute(sql, timeout=conf.timeout)
+ port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor:
+ cursor.execute(sql, {"param": sql_param}, timeout=conf.timeout)
res = cursor.fetchall()
res_list = [TableSchema(*item) for item in res]
return res_list
elif ds.type == 'doris':
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
port=conf.port, db=conf.database, connect_timeout=conf.timeout,
- read_timeout=conf.timeout) as conn, conn.cursor() as cursor:
- cursor.execute(sql)
+ read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor:
+ cursor.execute(sql, (sql_param,))
res = cursor.fetchall()
res_list = [TableSchema(*item) for item in res]
return res_list
elif ds.type == 'redshift':
with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username,
password=conf.password,
- timeout=conf.timeout) as conn, conn.cursor() as cursor:
- cursor.execute(sql)
+ timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor:
+ cursor.execute(sql, (sql_param,))
res = cursor.fetchall()
res_list = [TableSchema(*item) for item in res]
return res_list
@@ -304,34 +340,35 @@ def get_tables(ds: CoreDatasource):
def get_fields(ds: CoreDatasource, table_name: str = None):
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
db = DB.get_db(ds.type)
- sql = get_field_sql(ds, conf, table_name)
+ sql, p1, p2 = get_field_sql(ds, conf, table_name)
if db.connect_type == ConnectType.sqlalchemy:
with get_session(ds) as session:
- with session.execute(text(sql)) as result:
+ with session.execute(text(sql), {"param1": p1, "param2": p2}) as result:
res = result.fetchall()
res_list = [ColumnSchema(*item) for item in res]
return res_list
else:
+ extra_config_dict = get_extra_config(conf)
if ds.type == 'dm':
with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
- port=conf.port) as conn, conn.cursor() as cursor:
- cursor.execute(sql, timeout=conf.timeout)
+ port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor:
+ cursor.execute(sql, {"param1": p1, "param2": p2}, timeout=conf.timeout)
res = cursor.fetchall()
res_list = [ColumnSchema(*item) for item in res]
return res_list
elif ds.type == 'doris':
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
port=conf.port, db=conf.database, connect_timeout=conf.timeout,
- read_timeout=conf.timeout) as conn, conn.cursor() as cursor:
- cursor.execute(sql)
+ read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor:
+ cursor.execute(sql, (p1, p2))
res = cursor.fetchall()
res_list = [ColumnSchema(*item) for item in res]
return res_list
elif ds.type == 'redshift':
with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username,
password=conf.password,
- timeout=conf.timeout) as conn, conn.cursor() as cursor:
- cursor.execute(sql)
+ timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor:
+ cursor.execute(sql, (p1, p2))
res = cursor.fetchall()
res_list = [ColumnSchema(*item) for item in res]
return res_list
@@ -341,7 +378,7 @@ def get_fields(ds: CoreDatasource, table_name: str = None):
return res_list
-def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=False, table_name=None):
+def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=False):
while sql.endswith(';'):
sql = sql[:-1]
@@ -363,9 +400,10 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
raise ParseSQLResultError(str(ex))
else:
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration)))
+ extra_config_dict = get_extra_config(conf)
if ds.type == 'dm':
with dmPython.connect(user=conf.username, password=conf.password, server=conf.host,
- port=conf.port) as conn, conn.cursor() as cursor:
+ port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor:
try:
cursor.execute(sql, timeout=conf.timeout)
res = cursor.fetchall()
@@ -384,7 +422,7 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
elif ds.type == 'doris':
with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host,
port=conf.port, db=conf.database, connect_timeout=conf.timeout,
- read_timeout=conf.timeout) as conn, conn.cursor() as cursor:
+ read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor:
try:
cursor.execute(sql)
res = cursor.fetchall()
@@ -403,7 +441,7 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
elif ds.type == 'redshift':
with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username,
password=conf.password,
- timeout=conf.timeout) as conn, conn.cursor() as cursor:
+ timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor:
try:
cursor.execute(sql)
res = cursor.fetchall()
@@ -421,17 +459,16 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
raise ParseSQLResultError(str(ex))
elif ds.type == 'es':
try:
- if table_name and table_name[0]:
- res, columns = get_es_data(conf, sql, table_name[0])
- columns = [field[0] for field in columns] if origin_column else [field[0].lower() for
- field in
- columns]
- result_list = [
- {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in
- enumerate(tuple_item)}
- for tuple_item in res
- ]
- return {"fields": columns, "data": result_list,
- "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))}
+ res, columns = get_es_data_by_http(conf, sql)
+ columns = [field.get('name') for field in columns] if origin_column else [field.get('name').lower() for
+ field in
+ columns]
+ result_list = [
+ {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in
+ enumerate(tuple(tuple_item))}
+ for tuple_item in res
+ ]
+ return {"fields": columns, "data": result_list,
+ "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))}
except Exception as ex:
raise Exception(str(ex))
diff --git a/backend/apps/db/db_sql.py b/backend/apps/db/db_sql.py
index 10af38f8..66616b70 100644
--- a/backend/apps/db/db_sql.py
+++ b/backend/apps/db/db_sql.py
@@ -5,27 +5,27 @@
def get_version_sql(ds: CoreDatasource, conf: DatasourceConf):
if ds.type == "mysql" or ds.type == "doris":
- return f"""
+ return """
SELECT VERSION()
"""
elif ds.type == "sqlServer":
- return f"""
+ return """
select SERVERPROPERTY('ProductVersion')
"""
elif ds.type == "pg" or ds.type == "excel":
- return f"""
+ return """
SELECT current_setting('server_version')
"""
elif ds.type == "oracle":
- return f"""
+ return """
SELECT version FROM v$instance
"""
elif ds.type == "ck":
- return f"""
+ return """
select version()
"""
elif ds.type == 'dm':
- return f"""
+ return """
SELECT * FROM v$version
"""
elif ds.type == 'redshift':
@@ -33,18 +33,18 @@ def get_version_sql(ds: CoreDatasource, conf: DatasourceConf):
def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = ''):
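+ # schema/database identifiers are now bound as query parameters (:param / %s)
+ # instead of being interpolated with f-strings, closing an injection vector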
- if ds.type == "mysql" or ds.type == "doris":
- return f"""
+ if ds.type == "mysql":
+ return """
SELECT
TABLE_NAME,
TABLE_COMMENT
FROM
information_schema.TABLES
WHERE
- TABLE_SCHEMA = '{conf.database}'
- """
+ TABLE_SCHEMA = :param
+ """, conf.database
elif ds.type == "sqlServer":
- return f"""
+ return """
SELECT
TABLE_NAME AS [TABLE_NAME],
ISNULL(ep.value, '') AS [TABLE_COMMENT]
@@ -57,10 +57,10 @@ def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = ''
AND ep.name = 'MS_Description'
WHERE
t.TABLE_TYPE IN ('BASE TABLE', 'VIEW')
- AND t.TABLE_SCHEMA = '{conf.dbSchema}'
- """
+ AND t.TABLE_SCHEMA = :param
+ """, conf.dbSchema
elif ds.type == "pg" or ds.type == "excel":
- return f"""
+ return """
SELECT c.relname AS TABLE_NAME,
COALESCE(d.description, obj_description(c.oid)) AS TABLE_COMMENT
FROM pg_class c
@@ -68,59 +68,59 @@ def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = ''
pg_namespace n ON n.oid = c.relnamespace
LEFT JOIN
pg_description d ON d.objoid = c.oid AND d.objsubid = 0
- WHERE n.nspname = '{conf.dbSchema}'
+ WHERE n.nspname = :param
AND c.relkind IN ('r', 'v', 'p', 'm')
AND c.relname NOT LIKE 'pg_%'
AND c.relname NOT LIKE 'sql_%'
ORDER BY c.relname \
- """
+ """, conf.dbSchema
elif ds.type == "oracle":
- return f"""
+ return """
SELECT
t.TABLE_NAME AS "TABLE_NAME",
NVL(c.COMMENTS, '') AS "TABLE_COMMENT"
FROM (
SELECT TABLE_NAME, 'TABLE' AS OBJECT_TYPE
FROM DBA_TABLES
- WHERE OWNER = '{conf.dbSchema}'
+ WHERE OWNER = :param
UNION ALL
SELECT VIEW_NAME AS TABLE_NAME, 'VIEW' AS OBJECT_TYPE
FROM DBA_VIEWS
- WHERE OWNER = '{conf.dbSchema}'
+ WHERE OWNER = :param
) t
LEFT JOIN DBA_TAB_COMMENTS c
ON t.TABLE_NAME = c.TABLE_NAME
AND c.TABLE_TYPE = t.OBJECT_TYPE
- AND c.OWNER = '{conf.dbSchema}'
+ AND c.OWNER = :param
ORDER BY t.TABLE_NAME
- """
+ """, conf.dbSchema
elif ds.type == "ck":
version = int(db_version.split('.')[0])
if version < 22:
- return f"""
+ return """
SELECT name, null as comment
FROM system.tables
- WHERE database = '{conf.database}'
+ WHERE database = :param
AND engine NOT IN ('Dictionary')
ORDER BY name
- """
+ """, conf.database
else:
- return f"""
+ return """
SELECT name, comment
FROM system.tables
- WHERE database = '{conf.database}'
+ WHERE database = :param
AND engine NOT IN ('Dictionary')
ORDER BY name
- """
+ """, conf.database
elif ds.type == 'dm':
- return f"""
+ return """
select table_name, comments
from all_tab_comments
- where owner='{conf.dbSchema}'
+ where owner=:param
AND (table_type = 'TABLE' or table_type = 'VIEW')
- """
+ """, conf.dbSchema
elif ds.type == 'redshift':
- return f"""
+ return """
SELECT
relname AS TableName,
obj_description(relfilenode::regclass, 'pg_class') AS TableDescription
@@ -128,13 +128,25 @@ def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = ''
pg_class
WHERE
relkind in ('r','p', 'f')
- AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = '{conf.dbSchema}')
- """
+ AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = %s)
+ """, conf.dbSchema
+ elif ds.type == "doris":
+ return """
+ SELECT
+ TABLE_NAME,
+ TABLE_COMMENT
+ FROM
+ information_schema.TABLES
+ WHERE
+ TABLE_SCHEMA = %s
+ """, conf.database
+ elif ds.type == "es":
+ return "", None
def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = None):
- if ds.type == "mysql" or ds.type == "doris":
- sql1 = f"""
+ if ds.type == "mysql":
+ sql1 = """
SELECT
COLUMN_NAME,
DATA_TYPE,
@@ -142,12 +154,12 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No
FROM
INFORMATION_SCHEMA.COLUMNS
WHERE
- TABLE_SCHEMA = '{conf.database}'
+ TABLE_SCHEMA = :param1
"""
- sql2 = f" AND TABLE_NAME = '{table_name}'" if table_name is not None and table_name != "" else ""
- return sql1 + sql2
+ sql2 = " AND TABLE_NAME = :param2" if table_name is not None and table_name != "" else ""
+ return sql1 + sql2, conf.database, table_name
elif ds.type == "sqlServer":
- sql1 = f"""
+ sql1 = """
SELECT
COLUMN_NAME AS [COLUMN_NAME],
DATA_TYPE AS [DATA_TYPE],
@@ -160,12 +172,28 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No
AND EP.minor_id = C.ORDINAL_POSITION
AND EP.name = 'MS_Description'
WHERE
- C.TABLE_SCHEMA = '{conf.dbSchema}'
+ C.TABLE_SCHEMA = :param1
"""
- sql2 = f" AND C.TABLE_NAME = '{table_name}'" if table_name is not None and table_name != "" else ""
- return sql1 + sql2
- elif ds.type == "pg" or ds.type == "excel" or ds.type == "redshift":
- sql1 = f"""
+ sql2 = " AND C.TABLE_NAME = :param2" if table_name is not None and table_name != "" else ""
+ return sql1 + sql2, conf.dbSchema, table_name
+ elif ds.type == "pg" or ds.type == "excel":
+ sql1 = """
+ SELECT a.attname AS COLUMN_NAME,
+ pg_catalog.format_type(a.atttypid, a.atttypmod) AS DATA_TYPE,
+ col_description(c.oid, a.attnum) AS COLUMN_COMMENT
+ FROM pg_catalog.pg_attribute a
+ JOIN
+ pg_catalog.pg_class c ON a.attrelid = c.oid
+ JOIN
+ pg_catalog.pg_namespace n ON n.oid = c.relnamespace
+ WHERE n.nspname = :param1
+ AND a.attnum > 0
+ AND NOT a.attisdropped \
+ """
+ sql2 = " AND c.relname = :param2" if table_name is not None and table_name != "" else ""
+ return sql1 + sql2, conf.dbSchema, table_name
+ elif ds.type == "redshift":
+ sql1 = """
SELECT a.attname AS COLUMN_NAME,
pg_catalog.format_type(a.atttypid, a.atttypmod) AS DATA_TYPE,
col_description(c.oid, a.attnum) AS COLUMN_COMMENT
@@ -174,14 +202,14 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No
pg_catalog.pg_class c ON a.attrelid = c.oid
JOIN
pg_catalog.pg_namespace n ON n.oid = c.relnamespace
- WHERE n.nspname = '{conf.dbSchema}'
+ WHERE n.nspname = %s
AND a.attnum > 0
AND NOT a.attisdropped \
"""
- sql2 = f" AND c.relname = '{table_name}'" if table_name is not None and table_name != "" else ""
- return sql1 + sql2
+ sql2 = " AND c.relname = %s" if table_name is not None and table_name != "" else ""
+ return sql1 + sql2, conf.dbSchema, table_name
elif ds.type == "oracle":
- sql1 = f"""
+ sql1 = """
SELECT
col.COLUMN_NAME AS "COLUMN_NAME",
(CASE
@@ -201,23 +229,23 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No
AND col.TABLE_NAME = com.TABLE_NAME
AND col.COLUMN_NAME = com.COLUMN_NAME
WHERE
- col.OWNER = '{conf.dbSchema}'
+ col.OWNER = :param1
"""
- sql2 = f" AND col.TABLE_NAME = '{table_name}'" if table_name is not None and table_name != "" else ""
- return sql1 + sql2
+ sql2 = " AND col.TABLE_NAME = :param2" if table_name is not None and table_name != "" else ""
+ return sql1 + sql2, conf.dbSchema, table_name
elif ds.type == "ck":
- sql1 = f"""
+ sql1 = """
SELECT
name AS COLUMN_NAME,
type AS DATA_TYPE,
comment AS COLUMN_COMMENT
FROM system.columns
- WHERE database = '{conf.database}'
+ WHERE database = :param1
"""
- sql2 = f" AND table = '{table_name}'" if table_name is not None and table_name != "" else ""
- return sql1 + sql2
+ sql2 = " AND table = :param2" if table_name is not None and table_name != "" else ""
+ return sql1 + sql2, conf.database, table_name
elif ds.type == 'dm':
- sql1 = f"""
+ sql1 = """
SELECT
c.COLUMN_NAME AS "COLUMN_NAME",
c.DATA_TYPE AS "DATA_TYPE",
@@ -230,7 +258,22 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No
AND c.TABLE_NAME = com.TABLE_NAME
AND c.COLUMN_NAME = com.COLUMN_NAME
WHERE
- c.OWNER = '{conf.dbSchema}'
+ c.OWNER = :param1
+ """
+ sql2 = " AND c.TABLE_NAME = :param2" if table_name is not None and table_name != "" else ""
+ return sql1 + sql2, conf.dbSchema, table_name
+ elif ds.type == "doris":
+ sql1 = """
+ SELECT
+ COLUMN_NAME,
+ DATA_TYPE,
+ COLUMN_COMMENT
+ FROM
+ INFORMATION_SCHEMA.COLUMNS
+ WHERE
+ TABLE_SCHEMA = %s
"""
- sql2 = f" AND c.TABLE_NAME = '{table_name}'" if table_name is not None and table_name != "" else ""
- return sql1 + sql2
+ sql2 = " AND TABLE_NAME = %s" if table_name is not None and table_name != "" else ""
+ return sql1 + sql2, conf.database, table_name
+ elif ds.type == "es":
+ return "", None, None
diff --git a/backend/apps/db/es_engine.py b/backend/apps/db/es_engine.py
index 05c5797b..f8cf1f2f 100644
--- a/backend/apps/db/es_engine.py
+++ b/backend/apps/db/es_engine.py
@@ -2,12 +2,13 @@
# Date: 2025/9/9
import json
+from base64 import b64encode
import requests
from elasticsearch import Elasticsearch
-from fastapi import HTTPException
from apps.datasource.models.datasource import DatasourceConf
+from common.error import SingleMessageError
def get_es_connect(conf: DatasourceConf):
@@ -60,29 +61,57 @@ def get_es_fields(conf: DatasourceConf, table_name: str):
return res
-def get_es_data(conf: DatasourceConf, sql: str, table_name: str):
- r = requests.post(f"{conf.host}/_sql/translate", json={"query": sql})
- if r.json().get('error'):
- print(json.dumps(r.json()))
+# def get_es_data(conf: DatasourceConf, sql: str, table_name: str):
+# r = requests.post(f"{conf.host}/_sql/translate", json={"query": sql})
+# if r.json().get('error'):
+# print(json.dumps(r.json()))
+#
+# es_client = get_es_connect(conf)
+# response = es_client.search(
+# index=table_name,
+# body=json.dumps(r.json())
+# )
+#
+# # print(response)
+# fields = get_es_fields(conf, table_name)
+# res = []
+# for hit in response.get('hits').get('hits'):
+# item = []
+# if 'fields' in hit:
+# result = hit.get('fields') # {'title': ['Python'], 'age': [30]}
+# for field in fields:
+# v = result.get(field[0])
+# item.append(v[0]) if v else item.append(None)
+# res.append(tuple(item))
+# # print(hit['fields']['title'][0])
+# # elif '_source' in hit:
+# # print(hit.get('_source'))
+# return res, fields
- es_client = get_es_connect(conf)
- response = es_client.search(
- index=table_name,
- body=json.dumps(r.json())
- )
- # print(response)
- fields = get_es_fields(conf, table_name)
- res = []
- for hit in response.get('hits').get('hits'):
- item = []
- if 'fields' in hit:
- result = hit.get('fields') # {'title': ['Python'], 'age': [30]}
- for field in fields:
- v = result.get(field[0])
- item.append(v[0]) if v else item.append(None)
- res.append(tuple(item))
- # print(hit['fields']['title'][0])
- # elif '_source' in hit:
- # print(hit.get('_source'))
- return res, fields
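+# query Elasticsearch through its SQL endpoint (POST /_sql?format=json) over HTTP
+# with basic auth; returns (rows, columns)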
+def get_es_data_by_http(conf: DatasourceConf, sql: str):
+ url = conf.host
+ while url.endswith('/'):
+ url = url[:-1]
+
+ host = f'{url}/_sql?format=json'
+ credentials = f"{conf.username}:{conf.password}"
+ encoded_credentials = b64encode(credentials.encode()).decode()
+
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Basic {encoded_credentials}"
+ }
+
+ response = requests.post(host, data=json.dumps({"query": sql}), headers=headers, timeout=conf.timeout)
+
+ # print(response.json())
+ res = response.json()
+ if res.get('error'):
+ raise SingleMessageError(json.dumps(res))
+ fields = res.get('columns')
+ result = res.get('rows')
+ return result, fields
diff --git a/backend/apps/db/type.py b/backend/apps/db/type.py
deleted file mode 100644
index 1e48dc65..00000000
--- a/backend/apps/db/type.py
+++ /dev/null
@@ -1,18 +0,0 @@
-# Author: Junjun
-# Date: 2025/5/22
-from typing import Dict
-
-
-def db_type_relation() -> Dict:
- return {
- "mysql": "MySQL",
- "sqlServer": "Microsoft SQL Server",
- "pg": "PostgreSQL",
- "excel": "Excel/CSV",
- "oracle": "Oracle",
- "ck": "ClickHouse",
- "dm": "达梦",
- "doris": "Apache Doris",
- "redshift": "AWS Redshift",
- "es": "Elasticsearch"
- }
diff --git a/backend/apps/mcp/mcp.py b/backend/apps/mcp/mcp.py
index 15878792..433e4727 100644
--- a/backend/apps/mcp/mcp.py
+++ b/backend/apps/mcp/mcp.py
@@ -1,6 +1,7 @@
# Author: Junjun
# Date: 2025/7/1
-
+import json
+import traceback
from datetime import timedelta
import jwt
@@ -10,15 +11,17 @@
from jwt.exceptions import InvalidTokenError
from pydantic import ValidationError
from sqlmodel import select
+from starlette.responses import JSONResponse
from apps.chat.api.chat import create_chat
-from apps.chat.models.chat_model import ChatMcp, CreateChat, ChatStart, McpQuestion
+from apps.chat.models.chat_model import ChatMcp, CreateChat, ChatStart, McpQuestion, McpAssistant, ChatQuestion, \
+ ChatFinishStep
from apps.chat.task.llm import LLMService
from apps.system.crud.user import authenticate
from apps.system.crud.user import get_db_user
from apps.system.models.system_model import UserWsModel
from apps.system.models.user import UserModel
-from apps.system.schemas.system_schema import BaseUserDTO
+from apps.system.schemas.system_schema import BaseUserDTO, AssistantHeader
from apps.system.schemas.system_schema import UserInfoDTO
from common.core import security
from common.core.config import settings
@@ -106,8 +109,93 @@ async def mcp_question(session: SessionDep, chat: McpQuestion):
raise HTTPException(status_code=400, detail="Inactive user")
mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question)
- # ask
- llm_service = await LLMService.create(session_user, mcp_chat)
- llm_service.init_record()
- return StreamingResponse(llm_service.run_task(False), media_type="text/event-stream")
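+ # run the task in the background; await_result() yields chunks, either streamed
+ # as SSE or drained into a single JSON response below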
+ try:
+ llm_service = await LLMService.create(session_user, mcp_chat)
+ llm_service.init_record()
+ llm_service.run_task_async(False, chat.stream)
+ except Exception as e:
+ traceback.print_exc()
+
+ if chat.stream:
+ def _err(_e: Exception):
+ yield str(_e) + '\n\n'
+
+ return StreamingResponse(_err(e), media_type="text/event-stream")
+ else:
+ return JSONResponse(
+ content={'message': str(e)},
+ status_code=500,
+ )
+ if chat.stream:
+ return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
+ else:
+ res = llm_service.await_result()
+ raw_data = {}
+ for chunk in res:
+ if chunk:
+ raw_data = chunk
+ status_code = 200
+ if not raw_data.get('success'):
+ status_code = 500
+
+ return JSONResponse(
+ content=raw_data,
+ status_code=status_code,
+ )
+
+
+@router.post("/mcp_assistant", operation_id="mcp_assistant")
+async def mcp_assistant(session: SessionDep, chat: McpAssistant):
+ session_user = BaseUserDTO(**{
+ "id": -1, "account": 'sqlbot-mcp-assistant', "oid": 1, "assistant_id": -1, "password": '', "language": "zh-CN"
+ })
+ # session_user: UserModel = get_db_user(session=session, user_id=1)
+ # session_user.oid = 1
+ c = create_chat(session, session_user, CreateChat(origin=1), False)
+
+ # build assistant param
+ configuration = {"endpoint": chat.url}
+ # authorization = [{"key": "x-de-token",
+ # "value": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ1aWQiOjEsIm9pZCI6MSwiZXhwIjoxNzU4NTEyMDA2fQ.3NR-pgnADLdXZtI3dXX5-LuxfGYRvYD9kkr2de7KRP0",
+ # "target": "header"}]
+ mcp_assistant_header = AssistantHeader(id=1, name='mcp_assist', domain='', type=1,
+ configuration=json.dumps(configuration),
+ certificate=chat.authorization)
+
+ # assistant question
+ mcp_chat = ChatQuestion(chat_id=c.id, question=chat.question)
+ # ask
+ try:
+ llm_service = await LLMService.create(session_user, mcp_chat, mcp_assistant_header)
+ llm_service.init_record()
+ llm_service.run_task_async(False, chat.stream, ChatFinishStep.QUERY_DATA)
+ except Exception as e:
+ traceback.print_exc()
+
+ if chat.stream:
+ def _err(_e: Exception):
+ yield str(_e) + '\n\n'
+
+ return StreamingResponse(_err(e), media_type="text/event-stream")
+ else:
+ return JSONResponse(
+ content={'message': str(e)},
+ status_code=500,
+ )
+ if chat.stream:
+ return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
+ else:
+ res = llm_service.await_result()
+ raw_data = {}
+ for chunk in res:
+ if chunk:
+ raw_data = chunk
+ status_code = 200
+ if not raw_data.get('success'):
+ status_code = 500
+
+ return JSONResponse(
+ content=raw_data,
+ status_code=status_code,
+ )
diff --git a/backend/apps/system/crud/assistant.py b/backend/apps/system/crud/assistant.py
index 01e0c6f0..912218b9 100644
--- a/backend/apps/system/crud/assistant.py
+++ b/backend/apps/system/crud/assistant.py
@@ -211,7 +211,7 @@ def get_ds_engine(ds: AssistantOutDsSchema) -> Engine:
password=ds.password,
database=ds.dataBase,
driver='',
- extraJdbc=ds.extraParams,
+ extraJdbc=ds.extraParams or '',
dbSchema=ds.db_schema or ''
)
conf.extraJdbc = ''
diff --git a/backend/apps/system/schemas/system_schema.py b/backend/apps/system/schemas/system_schema.py
index 79eca6f4..728ca480 100644
--- a/backend/apps/system/schemas/system_schema.py
+++ b/backend/apps/system/schemas/system_schema.py
@@ -1,8 +1,9 @@
+import re
from typing import Optional
+
from pydantic import BaseModel, Field, field_validator
from common.core.schemas import BaseCreatorDTO
-import re
EMAIL_REGEX = re.compile(
r"^[a-zA-Z0-9]+([._-][a-zA-Z0-9]+)*@"
@@ -14,106 +15,121 @@
r"(?=.*[~!@#$%^&*()_+\-={}|:\"<>?`\[\];',./])"
r"[A-Za-z\d~!@#$%^&*()_+\-={}|:\"<>?`\[\];',./]{8,20}$"
)
+
+
class UserStatus(BaseCreatorDTO):
status: int = 1
-
-
+
class UserLanguage(BaseModel):
language: str
-
+
class BaseUser(BaseModel):
account: str = Field(min_length=1, max_length=100, description="用户账号")
oid: int
-
+
+
class BaseUserDTO(BaseUser, BaseCreatorDTO):
language: str = Field(pattern=r"^(zh-CN|en)$", default="zh-CN", description="用户语言")
password: str
status: int = 1
+
def to_dict(self):
return {
"id": self.id,
"account": self.account,
"oid": self.oid
}
+
@field_validator("language")
def validate_language(cls, lang: str) -> str:
if not re.fullmatch(r"^(zh-CN|en)$", lang):
raise ValueError("Language must be 'zh-CN' or 'en'")
return lang
+
class UserCreator(BaseUser):
name: str = Field(min_length=1, max_length=100, description="用户名")
email: str = Field(min_length=1, max_length=100, description="用户邮箱")
status: int = 1
- oid_list: Optional[list[int]] = None
-
+ oid_list: Optional[list[int]] = None
+
""" @field_validator("email")
def validate_email(cls, lang: str) -> str:
if not re.fullmatch(EMAIL_REGEX, lang):
raise ValueError("Email format is invalid!")
return lang """
+
class UserEditor(UserCreator, BaseCreatorDTO):
- pass
+ pass
+
class UserGrid(UserEditor):
create_time: int
language: str = "zh-CN"
- #space_name: Optional[str] = None
+ # space_name: Optional[str] = None
origin: str = ''
-
-
+
+
class PwdEditor(BaseModel):
pwd: str
new_pwd: str
-
+
+
class UserWsBase(BaseModel):
uid_list: list[int]
oid: Optional[int] = None
+
+
class UserWsDTO(UserWsBase):
weight: Optional[int] = 0
+
class UserWsEditor(BaseModel):
uid: int
- oid: int
+ oid: int
weight: int = 0
+
class UserInfoDTO(UserEditor):
language: str = "zh-CN"
weight: int = 0
isAdmin: bool = False
-
+
class AssistantBase(BaseModel):
name: str
domain: str
- type: int = 0
+ type: int = 0 # 0: basic assistant, 1: advanced, 4: embedded page
configuration: Optional[str] = None
description: Optional[str] = None
+
+
class AssistantDTO(AssistantBase, BaseCreatorDTO):
pass
+
class AssistantHeader(AssistantDTO):
unique: Optional[str] = None
certificate: Optional[str] = None
online: bool = False
-
+
class AssistantValidator(BaseModel):
valid: bool = False
id_match: bool = False
domain_match: bool = False
token: Optional[str] = None
-
+
def __init__(
- self,
- valid: bool = False,
- id_match: bool = False,
- domain_match: bool = False,
- token: Optional[str] = None,
- **kwargs
+ self,
+ valid: bool = False,
+ id_match: bool = False,
+ domain_match: bool = False,
+ token: Optional[str] = None,
+ **kwargs
):
super().__init__(
valid=valid,
@@ -122,23 +138,28 @@ def __init__(
token=token,
**kwargs
)
-
+
+
class WorkspaceUser(UserEditor):
weight: int
create_time: int
-
+
+
class UserWs(BaseCreatorDTO):
name: str
-
+
+
class UserWsOption(UserWs):
account: str
-
-
+
+
class AssistantFieldSchema(BaseModel):
id: Optional[int] = None
name: Optional[str] = None
type: Optional[str] = None
comment: Optional[str] = None
+
+
class AssistantTableSchema(BaseModel):
id: Optional[int] = None
name: Optional[str] = None
@@ -147,6 +168,7 @@ class AssistantTableSchema(BaseModel):
sql: Optional[str] = None
fields: Optional[list[AssistantFieldSchema]] = None
+
class AssistantOutDsBase(BaseModel):
id: Optional[int] = None
name: str
@@ -154,8 +176,8 @@ class AssistantOutDsBase(BaseModel):
type_name: Optional[str] = None
comment: Optional[str] = None
description: Optional[str] = None
-
-
+
+
class AssistantOutDsSchema(AssistantOutDsBase):
host: Optional[str] = None
port: Optional[int] = None
@@ -165,7 +187,8 @@ class AssistantOutDsSchema(AssistantOutDsBase):
db_schema: Optional[str] = None
extraParams: Optional[str] = None
tables: Optional[list[AssistantTableSchema]] = None
-
+
+
class AssistantUiSchema(BaseCreatorDTO):
theme: Optional[str] = None
header_font_color: Optional[str] = None
@@ -179,7 +202,3 @@ class AssistantUiSchema(BaseCreatorDTO):
name: Optional[str] = None
welcome: Optional[str] = None
welcome_desc: Optional[str] = None
-
-
-
-
\ No newline at end of file
diff --git a/backend/apps/template/generate_chart/generator.py b/backend/apps/template/generate_chart/generator.py
index 2c886277..a519b893 100644
--- a/backend/apps/template/generate_chart/generator.py
+++ b/backend/apps/template/generate_chart/generator.py
@@ -8,3 +8,7 @@ def get_chart_template():
def get_base_terminology_template():
template = get_base_template()
return template['template']['terminology']
+
+def get_base_data_training_template():
+ template = get_base_template()
+ return template['template']['data_training']
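
The new get_base_data_training_template() mirrors the terminology getter: same base template, different key. A DB-free sketch of that lookup, assuming get_base_template() returns a nested dict shaped like the one below (the actual template file is not shown in this diff):

```python
# Assumed shape of the base template; the real keys beyond
# template.terminology and template.data_training are not in the diff.
base_template = {
    "template": {
        "terminology": "Known terms:\n{terminologies}",
        "data_training": "Reference Q&A pairs:\n{examples}",
    }
}


def get_base_data_training_template_sketch() -> str:
    # Same access path as the new helper: template['template']['data_training']
    return base_template["template"]["data_training"]


print(get_base_data_training_template_sketch())
```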
diff --git a/backend/apps/terminology/curd/terminology.py b/backend/apps/terminology/curd/terminology.py
index 309a2c82..355f483e 100644
--- a/backend/apps/terminology/curd/terminology.py
+++ b/backend/apps/terminology/curd/terminology.py
@@ -1,23 +1,23 @@
import datetime
import logging
import traceback
-from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional
+from typing import List, Optional, Any
from xml.dom.minidom import parseString
import dicttoxml
+from sqlalchemy import BigInteger
from sqlalchemy import and_, or_, select, func, delete, update, union
-from sqlalchemy import create_engine, text
+from sqlalchemy import text
from sqlalchemy.orm import aliased
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm.session import Session
from apps.ai_model.embedding import EmbeddingModelCache
+from apps.datasource.models.datasource import CoreDatasource
from apps.template.generate_chart.generator import get_base_terminology_template
from apps.terminology.models.terminology_model import Terminology, TerminologyInfo
from common.core.config import settings
from common.core.deps import SessionDep, Trans
-
-executor = ThreadPoolExecutor(max_workers=200)
+from common.utils.embedding_threads import run_save_terminology_embeddings
def page_terminology(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None,
@@ -82,6 +82,16 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
.subquery()
)
+ # Create a subquery to fetch datasource names, adding a type cast
+ datasource_names_subquery = (
+ select(
+ func.jsonb_array_elements(Terminology.datasource_ids).cast(BigInteger).label('ds_id'),
+ Terminology.id.label('term_id')
+ )
+ .where(Terminology.id.in_(paginated_parent_ids))
+ .subquery()
+ )
+
# Main query
stmt = (
select(
@@ -89,13 +99,34 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
Terminology.word,
Terminology.create_time,
Terminology.description,
- children_subquery.c.other_words
+ Terminology.specific_ds,
+ Terminology.datasource_ids,
+ children_subquery.c.other_words,
+ func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names')
)
.outerjoin(
children_subquery,
Terminology.id == children_subquery.c.pid
)
+ # Join the datasource-name subquery with the CoreDatasource table
+ .outerjoin(
+ datasource_names_subquery,
+ datasource_names_subquery.c.term_id == Terminology.id
+ )
+ .outerjoin(
+ CoreDatasource,
+ CoreDatasource.id == datasource_names_subquery.c.ds_id
+ )
.where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid))
+ .group_by(
+ Terminology.id,
+ Terminology.word,
+ Terminology.create_time,
+ Terminology.description,
+ Terminology.specific_ds,
+ Terminology.datasource_ids,
+ children_subquery.c.other_words
+ )
.order_by(Terminology.create_time.desc())
)
else:
@@ -118,17 +149,59 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
.subquery()
)
+ children_subquery = (
+ select(
+ child.pid,
+ func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words')
+ )
+ .where(child.pid.isnot(None))
+ .group_by(child.pid)
+ .subquery()
+ )
+
+ # Create a subquery to fetch datasource names
+ datasource_names_subquery = (
+ select(
+ func.jsonb_array_elements(Terminology.datasource_ids).cast(BigInteger).label('ds_id'),
+ Terminology.id.label('term_id')
+ )
+ .where(Terminology.id.in_(paginated_parent_ids))
+ .subquery()
+ )
+
stmt = (
select(
Terminology.id,
Terminology.word,
Terminology.create_time,
Terminology.description,
- func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words')
+ Terminology.specific_ds,
+ Terminology.datasource_ids,
+ children_subquery.c.other_words,
+ func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names')
+ )
+ .outerjoin(
+ children_subquery,
+ Terminology.id == children_subquery.c.pid
+ )
+ # Join the datasource-name subquery with the CoreDatasource table
+ .outerjoin(
+ datasource_names_subquery,
+ datasource_names_subquery.c.term_id == Terminology.id
+ )
+ .outerjoin(
+ CoreDatasource,
+ CoreDatasource.id == datasource_names_subquery.c.ds_id
)
- .outerjoin(child, and_(Terminology.id == child.pid))
.where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid))
- .group_by(Terminology.id, Terminology.word)
+ .group_by(Terminology.id,
+ Terminology.word,
+ Terminology.create_time,
+ Terminology.description,
+ Terminology.specific_ds,
+ Terminology.datasource_ids,
+ children_subquery.c.other_words
+ )
.order_by(Terminology.create_time.desc())
)
@@ -141,6 +214,9 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
create_time=row.create_time,
description=row.description,
other_words=row.other_words if row.other_words else [],
+ specific_ds=row.specific_ds if row.specific_ds is not None else False,
+ datasource_ids=row.datasource_ids if row.datasource_ids is not None else [],
+ datasource_names=row.datasource_names if row.datasource_names is not None else [],
))
return current_page, page_size, total_count, total_pages, _list
@@ -148,7 +224,13 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
create_time = datetime.datetime.now()
- parent = Terminology(word=info.word, create_time=create_time, description=info.description, oid=oid)
+
+ specific_ds = info.specific_ds if info.specific_ds is not None else False
+ datasource_ids = info.datasource_ids if info.datasource_ids is not None else []
+
+ parent = Terminology(word=info.word, create_time=create_time, description=info.description, oid=oid,
+ specific_ds=specific_ds,
+ datasource_ids=datasource_ids)
words = [info.word]
for child in info.other_words:
@@ -177,13 +259,14 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
if other_word.strip() == "":
continue
_list.append(
- Terminology(pid=result.id, word=other_word, create_time=create_time, oid=oid))
+ Terminology(pid=result.id, word=other_word, create_time=create_time, oid=oid,
+ specific_ds=specific_ds, datasource_ids=datasource_ids))
session.bulk_save_objects(_list)
session.flush()
session.commit()
# embedding
- run_save_embeddings([result.id])
+ run_save_terminology_embeddings([result.id])
return result.id
@@ -216,9 +299,14 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
if exists:
raise Exception(trans("i18n_terminology.exists_in_db"))
+ specific_ds = info.specific_ds if info.specific_ds is not None else False
+ datasource_ids = info.datasource_ids if info.datasource_ids is not None else []
+
stmt = update(Terminology).where(and_(Terminology.id == info.id)).values(
word=info.word,
description=info.description,
+ specific_ds=specific_ds,
+ datasource_ids=datasource_ids
)
session.execute(stmt)
session.commit()
@@ -234,13 +322,14 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
if other_word.strip() == "":
continue
_list.append(
- Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid))
+ Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid,
+ specific_ds=specific_ds, datasource_ids=datasource_ids))
session.bulk_save_objects(_list)
session.flush()
session.commit()
# embedding
- run_save_embeddings([info.id])
+ run_save_terminology_embeddings([info.id])
return info.id
@@ -251,37 +340,31 @@ def delete_terminology(session: SessionDep, ids: list[int]):
session.commit()
-def run_save_embeddings(ids: List[int]):
- executor.submit(save_embeddings, ids)
-
-
-def fill_empty_embeddings():
- executor.submit(run_fill_empty_embeddings)
+# def run_save_embeddings(ids: List[int]):
+# executor.submit(save_embeddings, ids)
+#
+#
+# def fill_empty_embeddings():
+# executor.submit(run_fill_empty_embeddings)
-def run_fill_empty_embeddings():
+def run_fill_empty_embeddings(session: Session):
if not settings.EMBEDDING_ENABLED:
return
- engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
- session_maker = sessionmaker(bind=engine)
- session = session_maker()
stmt1 = select(Terminology.id).where(and_(Terminology.embedding.is_(None), Terminology.pid.is_(None)))
stmt2 = select(Terminology.pid).where(and_(Terminology.embedding.is_(None), Terminology.pid.isnot(None))).distinct()
combined_stmt = union(stmt1, stmt2)
results = session.execute(combined_stmt).scalars().all()
- save_embeddings(results)
+ save_embeddings(session, results)
-def save_embeddings(ids: List[int]):
+def save_embeddings(session: Session, ids: List[int]):
if not settings.EMBEDDING_ENABLED:
return
if not ids or len(ids) == 0:
return
try:
- engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI))
- session_maker = sessionmaker(bind=engine)
- session = session_maker()
_list = session.query(Terminology).filter(or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))).all()
@@ -304,17 +387,22 @@ def save_embeddings(ids: List[int]):
embedding_sql = f"""
SELECT id, pid, word, similarity
FROM
-(SELECT id, pid, word, oid,
+(SELECT id, pid, word, oid, specific_ds, datasource_ids,
( 1 - (embedding <=> :embedding_array) ) AS similarity
FROM terminology AS child
) TEMP
-WHERE similarity > {settings.EMBEDDING_SIMILARITY} and oid = :oid
+WHERE similarity > {settings.EMBEDDING_TERMINOLOGY_SIMILARITY} AND oid = :oid
+AND (
+ (:datasource IS NULL AND (specific_ds = false OR specific_ds IS NULL))
+ OR
+ (:datasource IS NOT NULL AND ((specific_ds = false OR specific_ds IS NULL) OR (specific_ds = true AND datasource_ids IS NOT NULL AND datasource_ids @> jsonb_build_array(:datasource))))
+)
ORDER BY similarity DESC
-LIMIT {settings.EMBEDDING_TOP_COUNT}
+LIMIT {settings.EMBEDDING_TERMINOLOGY_TOP_COUNT}
"""
-def select_terminology_by_word(session: SessionDep, word: str, oid: int):
+def select_terminology_by_word(session: SessionDep, word: str, oid: int, datasource: Optional[int] = None):
if word.strip() == "":
return []
@@ -331,7 +419,26 @@ def select_terminology_by_word(session: SessionDep, word: str, oid: int):
)
)
- results = session.execute(stmt, {'sentence': word}).fetchall()
+ if datasource is not None:
+ stmt = stmt.where(
+ or_(
+ or_(Terminology.specific_ds == False, Terminology.specific_ds.is_(None)),
+ and_(
+ Terminology.specific_ds == True,
+ Terminology.datasource_ids.isnot(None),
+ text("datasource_ids @> jsonb_build_array(:datasource)")
+ )
+ )
+ )
+ else:
+ stmt = stmt.where(or_(Terminology.specific_ds == False, Terminology.specific_ds.is_(None)))
+
+ # Execute the query
+ params: dict[str, Any] = {'sentence': word}
+ if datasource is not None:
+ params['datasource'] = datasource
+
+ results = session.execute(stmt, params).fetchall()
for row in results:
_list.append(Terminology(id=row.id, word=row.word, pid=row.pid))
@@ -342,7 +449,8 @@ def select_terminology_by_word(session: SessionDep, word: str, oid: int):
embedding = model.embed_query(word)
- results = session.execute(text(embedding_sql), {'embedding_array': str(embedding), 'oid': oid})
+ results = session.execute(text(embedding_sql), {'embedding_array': str(embedding), 'oid': oid,
+ 'datasource': datasource}).fetchall()
for row in results:
_list.append(Terminology(id=row.id, word=row.word, pid=row.pid))
@@ -418,10 +526,11 @@ def to_xml_string(_dict: list[dict] | dict, root: str = 'terminologies') -> str:
return pretty_xml
-def get_terminology_template(session: SessionDep, question: str, oid: Optional[int] = 1) -> str:
+def get_terminology_template(session: SessionDep, question: str, oid: Optional[int] = 1,
+ datasource: Optional[int] = None) -> str:
if not oid:
oid = 1
- _results = select_terminology_by_word(session, question, oid)
+ _results = select_terminology_by_word(session, question, oid, datasource)
if _results and len(_results) > 0:
terminology = to_xml_string(_results)
template = get_base_terminology_template().format(terminologies=terminology)
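
The datasource-scoping predicate added above appears twice, once in the SQLAlchemy branch and once in the raw embedding SQL: a term is visible when it is not datasource-specific, or when the requested datasource id is contained in its datasource_ids JSONB array (datasource_ids @> jsonb_build_array(:datasource)). A minimal, database-free sketch of that visibility rule; the function and argument names are illustrative only:

```python
from typing import Optional


def terminology_visible(specific_ds: Optional[bool],
                        datasource_ids: Optional[list[int]],
                        datasource: Optional[int]) -> bool:
    if not specific_ds:        # specific_ds = false OR specific_ds IS NULL
        return True
    if datasource is None:     # no datasource given: specific terms stay hidden
        return False
    # mirrors: datasource_ids IS NOT NULL AND
    #          datasource_ids @> jsonb_build_array(:datasource)
    return bool(datasource_ids) and datasource in datasource_ids


assert terminology_visible(False, [], None) is True      # global term, no ds filter
assert terminology_visible(True, [7, 9], 9) is True      # scoped term, ds matches
assert terminology_visible(True, [7, 9], None) is False  # scoped term, no ds given
assert terminology_visible(True, None, 9) is False       # scoped term, empty id list
```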
diff --git a/backend/apps/terminology/models/terminology_model.py b/backend/apps/terminology/models/terminology_model.py
index 57c35ce4..b9048659 100644
--- a/backend/apps/terminology/models/terminology_model.py
+++ b/backend/apps/terminology/models/terminology_model.py
@@ -3,7 +3,8 @@
from pgvector.sqlalchemy import VECTOR
from pydantic import BaseModel
-from sqlalchemy import Column, Text, BigInteger, DateTime, Identity
+from sqlalchemy import Column, Text, BigInteger, DateTime, Identity, Boolean
+from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import SQLModel, Field
@@ -16,6 +17,8 @@ class Terminology(SQLModel, table=True):
word: Optional[str] = Field(max_length=255)
description: Optional[str] = Field(sa_column=Column(Text, nullable=True))
embedding: Optional[List[float]] = Field(sa_column=Column(VECTOR(), nullable=True))
+ specific_ds: Optional[bool] = Field(sa_column=Column(Boolean, default=False))
+ datasource_ids: Optional[list[int]] = Field(sa_column=Column(JSONB), default=[])
class TerminologyInfo(BaseModel):
@@ -24,5 +27,6 @@ class TerminologyInfo(BaseModel):
word: Optional[str] = None
description: Optional[str] = None
other_words: Optional[List[str]] = []
-
-
+ specific_ds: Optional[bool] = False
+ datasource_ids: Optional[list[int]] = []
+ datasource_names: Optional[list[str]] = []
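
For reference, a short usage sketch of the extended TerminologyInfo payload. The class below is a standalone Pydantic copy of the fields shown above, and the datasource id is hypothetical; datasource_names is output-only, populated by page_terminology via the CoreDatasource join rather than supplied by callers:

```python
from typing import List, Optional

from pydantic import BaseModel


class TerminologyInfoSketch(BaseModel):
    id: Optional[int] = None
    word: Optional[str] = None
    description: Optional[str] = None
    other_words: Optional[List[str]] = []
    specific_ds: Optional[bool] = False
    datasource_ids: Optional[List[int]] = []
    datasource_names: Optional[List[str]] = []


info = TerminologyInfoSketch(
    word="GMV",
    description="Gross merchandise volume",
    other_words=["gross merchandise value"],
    specific_ds=True,     # only visible to the listed datasources
    datasource_ids=[42],  # hypothetical CoreDatasource id
)
print(info.model_dump(exclude_none=True))
```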
diff --git a/backend/common/core/config.py b/backend/common/core/config.py
index 6a922ae9..ca0468b8 100644
--- a/backend/common/core/config.py
+++ b/backend/common/core/config.py
@@ -89,13 +89,24 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str:
LOCAL_MODEL_PATH: str = '/opt/sqlbot/models'
DEFAULT_EMBEDDING_MODEL: str = 'shibing624/text2vec-base-chinese'
EMBEDDING_ENABLED: bool = True
- EMBEDDING_SIMILARITY: float = 0.4
- EMBEDDING_TOP_COUNT: int = 5
+ EMBEDDING_DEFAULT_SIMILARITY: float = 0.4
+ EMBEDDING_TERMINOLOGY_SIMILARITY: float = EMBEDDING_DEFAULT_SIMILARITY
+ EMBEDDING_DATA_TRAINING_SIMILARITY: float = EMBEDDING_DEFAULT_SIMILARITY
+ EMBEDDING_DEFAULT_TOP_COUNT: int = 5
+ EMBEDDING_TERMINOLOGY_TOP_COUNT: int = EMBEDDING_DEFAULT_TOP_COUNT
+ EMBEDDING_DATA_TRAINING_TOP_COUNT: int = EMBEDDING_DEFAULT_TOP_COUNT
+
+ PARSE_REASONING_BLOCK_ENABLED: bool = True
+ DEFAULT_REASONING_CONTENT_START: str = '