diff --git a/.github/workflows/package_and_push.yml b/.github/workflows/package_and_push.yml index 3ec89e7c..f946d3c4 100644 --- a/.github/workflows/package_and_push.yml +++ b/.github/workflows/package_and_push.yml @@ -9,6 +9,11 @@ on: description: 'Image Tag' default: 'v0.9.5' required: true + packageEE: + description: '是否打包企业版' + default: false + required: true + type: boolean jobs: package-and-push-to-aliyun-oss: @@ -31,6 +36,7 @@ jobs: ALIYUN_OSS_BUCKET_ENDPOINT: ${{ secrets.ALIYUN_OSS_BUCKET_ENDPOINT }} ALIYUN_OSS_ACCESS_KEY: ${{ secrets.ALIYUN_OSS_ACCESS_KEY }} ALIYUN_OSS_ACCESS_SECRET: ${{ secrets.ALIYUN_OSS_ACCESS_SECRET }} + PACKAGE_EE: ${{ github.event.inputs.packageEE }} run: | DOCKER_IMAGE=${ALIYUN_REGISTRY_HOST}/dataease/sqlbot cd installer @@ -76,12 +82,26 @@ jobs: --exclude .git \ --exclude images \ --exclude docker \ - -czvf $package_online . + -czvf $package_offline . + + if [[ ${PACKAGE_EE} == 'true' ]]; then + package_offline_ee="sqlbot-offline-installer-${TAG_NAME}${platform}-ee.tar.gz" + touch $package_offline_ee + tar --transform "s/^\./sqlbot-offline-installer-${TAG_NAME}${platform}-ee/" \ + --exclude $package_online \ + --exclude $package_offline \ + --exclude $package_offline_ee \ + --exclude .git \ + -czvf $package_offline_ee . + + ossutil cp -rf ${package_offline_ee} oss://$ALIYUN_OSS_BUCKET/sqlbot/${package_offline_ee} --access-key-id=$ALIYUN_OSS_ACCESS_KEY --access-key-secret=$ALIYUN_OSS_ACCESS_SECRET --endpoint=${ALIYUN_OSS_BUCKET_ENDPOINT} + fi #Sync files to OSS - ossutil cp -rf ${package_offline} oss://$ALIYUN_OSS_BUCKET/sqlbot/${package_offline} --access-key-id=$ALIYUN_OSS_ACCESS_KEY --access-key-secret=$ALIYUN_OSS_ACCESS_SECRET --endpoint=${ALIYUN_OSS_BUCKET_ENDPOINT} + ossutil cp -rf ${package_offline} oss://$ALIYUN_OSS_BUCKET/sqlbot/${package_offline} --access-key-id=$ALIYUN_OSS_ACCESS_KEY --access-key-secret=$ALIYUN_OSS_ACCESS_SECRET --endpoint=${ALIYUN_OSS_BUCKET_ENDPOINT} ossutil cp -rf ${package_online} oss://$ALIYUN_OSS_BUCKET/sqlbot/${package_online} --access-key-id=$ALIYUN_OSS_ACCESS_KEY --access-key-secret=$ALIYUN_OSS_ACCESS_SECRET --endpoint=${ALIYUN_OSS_BUCKET_ENDPOINT} + diff --git a/README.md b/README.md index 98e00c88..d3622ab9 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ SQLBot 是一款基于大模型和 RAG 的智能问数系统。SQLBot 的优势 ## 工作原理 -system-arch +system-arch ## 快速开始 @@ -31,8 +31,9 @@ docker run -d \ -p 8000:8000 \ -p 8001:8001 \ -v ./data/sqlbot/excel:/opt/sqlbot/data/excel \ + -v ./data/sqlbot/file:/opt/sqlbot/data/file \ -v ./data/sqlbot/images:/opt/sqlbot/images \ - -v ./data/sqlbot/logs:/opt/sqlbot/logs \ + -v ./data/sqlbot/logs:/opt/sqlbot/app/logs \ -v ./data/postgresql:/var/lib/postgresql/data \ --privileged=true \ dataease/sqlbot @@ -70,6 +71,7 @@ docker run -d \ - [1Panel](https://github.com/1panel-dev/1panel/) - 现代化、开源的 Linux 服务器运维管理面板 - [MaxKB](https://github.com/1panel-dev/MaxKB/) - 强大易用的企业级智能体平台 - [JumpServer](https://github.com/jumpserver/jumpserver/) - 广受欢迎的开源堡垒机 +- [Cordys CRM](https://github.com/1Panel-dev/CordysCRM) - 新一代的开源 AI CRM 系统 - [Halo](https://github.com/halo-dev/halo/) - 强大易用的开源建站工具 - [MeterSphere](https://github.com/metersphere/metersphere/) - 新一代的开源持续测试工具 diff --git a/backend/alembic/env.py b/backend/alembic/env.py index 3f72963a..79dfbdc6 100755 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -26,8 +26,10 @@ # from apps.settings.models.setting_models import SQLModel # from apps.chat.models.chat_model import SQLModel 
from apps.terminology.models.terminology_model import SQLModel +# from apps.data_training.models.data_training_model import SQLModel # from apps.dashboard.models.dashboard_model import SQLModel from common.core.config import settings # noqa +#from apps.datasource.models.datasource import SQLModel target_metadata = SQLModel.metadata diff --git a/backend/alembic/versions/042_data_training.py b/backend/alembic/versions/042_data_training.py new file mode 100644 index 00000000..ec44c89f --- /dev/null +++ b/backend/alembic/versions/042_data_training.py @@ -0,0 +1,41 @@ +"""042_data_training + +Revision ID: a487d9c69341 +Revises: c4c3c36b720d +Create Date: 2025-09-15 15:41:43.332771 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql +import pgvector + +# revision identifiers, used by Alembic. +revision = 'a487d9c69341' +down_revision = 'c4c3c36b720d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('data_training', + sa.Column('id', sa.BigInteger(), sa.Identity(always=True), nullable=False), + sa.Column('oid', sa.BigInteger(), nullable=True), + sa.Column('datasource', sa.BigInteger(), nullable=True), + sa.Column('create_time', sa.DateTime(), nullable=True), + sa.Column('question', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('embedding', pgvector.sqlalchemy.vector.VECTOR(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + op.drop_table('data_training') + # ### end Alembic commands ### diff --git a/backend/alembic/versions/043_modify_ds_id_type.py b/backend/alembic/versions/043_modify_ds_id_type.py new file mode 100644 index 00000000..4ef303a0 --- /dev/null +++ b/backend/alembic/versions/043_modify_ds_id_type.py @@ -0,0 +1,55 @@ +"""043_modify_ds_id_type + +Revision ID: dac062c1f7b1 +Revises: a487d9c69341 +Create Date: 2025-09-22 17:20:44.465735 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = 'dac062c1f7b1' +down_revision = 'a487d9c69341' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('core_datasource', 'id', + existing_type=sa.INTEGER(), + type_=sa.BigInteger(), + existing_nullable=False, + autoincrement=True) + op.alter_column('core_field', 'id', + existing_type=sa.INTEGER(), + type_=sa.BigInteger(), + existing_nullable=False, + autoincrement=True) + op.alter_column('core_table', 'id', + existing_type=sa.INTEGER(), + type_=sa.BigInteger(), + existing_nullable=False, + autoincrement=True) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.alter_column('core_table', 'id', + existing_type=sa.BigInteger(), + type_=sa.INTEGER(), + existing_nullable=False, + autoincrement=True) + op.alter_column('core_field', 'id', + existing_type=sa.BigInteger(), + type_=sa.INTEGER(), + existing_nullable=False, + autoincrement=True) + op.alter_column('core_datasource', 'id', + existing_type=sa.BigInteger(), + type_=sa.INTEGER(), + existing_nullable=False, + autoincrement=True) + # ### end Alembic commands ### diff --git a/backend/alembic/versions/044_table_relation.py b/backend/alembic/versions/044_table_relation.py new file mode 100644 index 00000000..9f655ed4 --- /dev/null +++ b/backend/alembic/versions/044_table_relation.py @@ -0,0 +1,29 @@ +"""044_table_relation + +Revision ID: 455b8ce69e80 +Revises: dac062c1f7b1 +Create Date: 2025-09-24 13:34:08.205659 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '455b8ce69e80' +down_revision = 'dac062c1f7b1' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('core_datasource', sa.Column('table_relation', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('core_datasource', 'table_relation') + # ### end Alembic commands ### diff --git a/backend/alembic/versions/045_modify_terminolog.py b/backend/alembic/versions/045_modify_terminolog.py new file mode 100644 index 00000000..a452bb6c --- /dev/null +++ b/backend/alembic/versions/045_modify_terminolog.py @@ -0,0 +1,31 @@ +"""045_modify_terminolog + +Revision ID: 45e7e52bf2b8 +Revises: 455b8ce69e80 +Create Date: 2025-09-25 14:49:24.521795 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '45e7e52bf2b8' +down_revision = '455b8ce69e80' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('terminology', sa.Column('specific_ds', sa.Boolean(), nullable=True)) + op.add_column('terminology', sa.Column('datasource_ids', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('terminology', 'datasource_ids') + op.drop_column('terminology', 'specific_ds') + # ### end Alembic commands ### diff --git a/backend/apps/api.py b/backend/apps/api.py index 14b07cea..8b836c0d 100644 --- a/backend/apps/api.py +++ b/backend/apps/api.py @@ -1,11 +1,12 @@ from fastapi import APIRouter -from apps.terminology.api import terminology from apps.chat.api import chat from apps.dashboard.api import dashboard_api -from apps.datasource.api import datasource -from apps.system.api import login, user, aimodel, workspace, assistant +from apps.data_training.api import data_training +from apps.datasource.api import datasource, table_relation from apps.mcp import mcp +from apps.system.api import login, user, aimodel, workspace, assistant +from apps.terminology.api import terminology api_router = APIRouter() api_router.include_router(login.router) @@ -14,9 +15,9 @@ api_router.include_router(assistant.router) api_router.include_router(aimodel.router) api_router.include_router(terminology.router) +api_router.include_router(data_training.router) api_router.include_router(datasource.router) api_router.include_router(chat.router) api_router.include_router(dashboard_api.router) api_router.include_router(mcp.router) - - +api_router.include_router(table_relation.router) diff --git a/backend/apps/chat/api/chat.py b/backend/apps/chat/api/chat.py index fa85d646..a28c03d7 100644 --- a/backend/apps/chat/api/chat.py +++ b/backend/apps/chat/api/chat.py @@ -13,7 +13,7 @@ delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, ExcelData from apps.chat.task.llm import LLMService -from common.core.deps import CurrentAssistant, SessionDep, CurrentUser +from common.core.deps import CurrentAssistant, SessionDep, CurrentUser, Trans router = APIRouter(tags=["Data Q&A"], prefix="/chat") @@ -105,14 +105,15 @@ async def start_chat(session: SessionDep, current_user: CurrentUser): @router.post("/recommend_questions/{chat_record_id}") async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int, current_assistant: CurrentAssistant): + def _return_empty(): + yield 'data:' + orjson.dumps({'content': '[]', 'type': 'recommended_question'}).decode() + '\n\n' + try: record = get_chat_record_by_id(session, chat_record_id) if not record: - raise HTTPException( - status_code=400, - detail=f"Chat record with id {chat_record_id} not found" - ) + return StreamingResponse(_return_empty(), media_type="text/event-stream") + request_question = ChatQuestion(chat_id=record.chat_id, question=record.question if record.question else '') llm_service = await LLMService.create(current_user, request_question, current_assistant, True) @@ -120,10 +121,11 @@ async def recommend_questions(session: SessionDep, current_user: CurrentUser, ch llm_service.run_recommend_questions_task_async() except Exception as e: traceback.print_exc() - raise HTTPException( - status_code=500, - detail=str(e) - ) + + def _err(_e: Exception): + yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n' + + return StreamingResponse(_err(e), media_type="text/event-stream") return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") @@ -143,7 +145,7 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que """ try: - llm_service = await LLMService.create(current_user, request_question, 
current_assistant) + llm_service = await LLMService.create(current_user, request_question, current_assistant, embedding=True) llm_service.init_record() llm_service.run_task_async() except Exception as e: @@ -199,10 +201,16 @@ def _err(_e: Exception): @router.post("/excel/export") -async def export_excel(excel_data: ExcelData): +async def export_excel(excel_data: ExcelData, trans: Trans): def inner(): _fields_list = [] data = [] + if not excel_data.data: + raise HTTPException( + status_code=500, + detail=trans("i18n_excel_export.data_is_empty") + ) + for _data in excel_data.data: _row = [] for field in excel_data.axis: diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py index eb418251..6cc30b83 100644 --- a/backend/apps/chat/models/chat_model.py +++ b/backend/apps/chat/models/chat_model.py @@ -40,6 +40,10 @@ class OperationEnum(Enum): CHOOSE_DATASOURCE = '6' GENERATE_DYNAMIC_SQL = '7' +class ChatFinishStep(Enum): + GENERATE_SQL = 1 + QUERY_DATA = 2 + GENERATE_CHART = 3 # TODO choose table / check connection / generate description @@ -136,7 +140,7 @@ class CreateChat(BaseModel): id: int = None question: str = None datasource: int = None - origin: Optional[int] = 0 + origin: Optional[int] = 0 # 0是页面上,mcp是1,小助手是2 class RenameChat(BaseModel): @@ -172,11 +176,13 @@ class AiModelQuestion(BaseModel): filter: str = [] sub_query: Optional[list[dict]] = None terminologies: str = "" + data_training: str = "" error_msg: str = "" def sql_sys_question(self): return get_sql_template()['system'].format(engine=self.engine, schema=self.db_schema, question=self.question, - lang=self.lang, terminologies=self.terminologies) + lang=self.lang, terminologies=self.terminologies, + data_training=self.data_training) def sql_user_question(self, current_time: str): return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=self.question, @@ -244,6 +250,7 @@ class McpQuestion(BaseModel): question: str = Body(description='用户提问') chat_id: int = Body(description='会话ID') token: str = Body(description='token') + stream: Optional[bool] = Body(description='是否流式输出,默认为true开启, 关闭false则返回JSON对象', default=True) class AxisObj(BaseModel): @@ -256,3 +263,10 @@ class ExcelData(BaseModel): axis: list[AxisObj] = [] data: list[dict] = [] name: str = 'Excel' + + +class McpAssistant(BaseModel): + question: str = Body(description='用户提问') + url: str = Body(description='第三方数据接口') + authorization: str = Body(description='第三方接口凭证') + stream: Optional[bool] = Body(description='是否流式输出,默认为true开启, 关闭false则返回JSON对象', default=True) diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 8cc1f470..a30a741d 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -6,7 +6,7 @@ import warnings from concurrent.futures import ThreadPoolExecutor, Future from datetime import datetime -from typing import Any, List, Optional, Union, Dict +from typing import Any, List, Optional, Union, Dict, Iterator import numpy as np import orjson @@ -16,7 +16,7 @@ from langchain.chat_models.base import BaseChatModel from langchain_community.utilities import SQLDatabase from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, BaseMessageChunk -from sqlalchemy import select +from sqlalchemy import and_, select from sqlalchemy.orm import sessionmaker from sqlmodel import Session @@ -28,9 +28,12 @@ get_old_questions, save_analysis_predict_record, rename_chat, get_chart_config, \ get_chat_chart_data, list_generate_sql_logs, 
list_generate_chart_logs, start_log, end_log, \ get_last_execute_sql_error -from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum +from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \ + ChatFinishStep +from apps.data_training.curd.data_training import get_training_template from apps.datasource.crud.datasource import get_table_schema from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user +from apps.datasource.embedding.ds_embedding import get_ds_embedding from apps.datasource.models.datasource import CoreDatasource from apps.db.db import exec_sql, get_version, check_connection from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds @@ -82,7 +85,7 @@ class LLMService: def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion, current_assistant: Optional[CurrentAssistant] = None, no_reasoning: bool = False, - config: LLMConfig = None): + embedding: bool = False, config: LLMConfig = None): self.chunk_list = [] # engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) # session_maker = sessionmaker(bind=engine) @@ -111,7 +114,8 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion, if not ds: raise SingleMessageError("No available datasource configuration found") chat_question.engine = (ds.type_name if ds.type != 'excel' else 'PostgreSQL') + get_version(ds) - chat_question.db_schema = get_table_schema(session=self.session, current_user=current_user, ds=ds) + chat_question.db_schema = get_table_schema(session=self.session, current_user=current_user, ds=ds, + question=chat_question.question, embedding=embedding) self.generate_sql_logs = list_generate_sql_logs(session=self.session, chart_id=chat_id) self.generate_chart_logs = list_generate_chart_logs(session=self.session, chart_id=chat_id) @@ -146,8 +150,6 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion, else: self.chat_question.error_msg = '' - self.init_messages() - @classmethod async def create(cls, *args, **kwargs): config: LLMConfig = await get_default_config() @@ -239,8 +241,9 @@ def generate_analysis(self): self.chat_question.data = orjson.dumps(data.get('data')).decode() analysis_msg: List[Union[BaseMessage, dict[str, Any]]] = [] + ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, - self.current_user.oid) + self.current_user.oid, ds_id) analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question())) analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question())) @@ -257,22 +260,14 @@ def generate_analysis(self): in analysis_msg]) full_thinking_text = '' full_analysis_text = '' - res = self.llm.stream(analysis_msg) token_usage = {} + res = process_stream(self.llm.stream(analysis_msg), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_analysis_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - 
get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_analysis_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk analysis_msg.append(AIMessage(full_analysis_text)) @@ -309,22 +304,14 @@ def generate_predict(self): in predict_msg]) full_thinking_text = '' full_predict_text = '' - res = self.llm.stream(predict_msg) token_usage = {} + res = process_stream(self.llm.stream(predict_msg), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_predict_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_predict_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk predict_msg.append(AIMessage(full_predict_text)) self.record = save_predict_answer(session=self.session, record_id=self.record.id, @@ -345,7 +332,9 @@ def generate_recommend_questions_task(self): if self.ds and not self.chat_question.db_schema: self.chat_question.db_schema = self.out_ds_instance.get_db_schema( self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session, - current_user=self.current_user, ds=self.ds) + current_user=self.current_user, ds=self.ds, + question=self.chat_question.question, + embedding=False) guess_msg: List[Union[BaseMessage, dict[str, Any]]] = [] guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question())) @@ -367,21 +356,13 @@ def generate_recommend_questions_task(self): full_thinking_text = '' full_guess_text = '' token_usage = {} - res = self.llm.stream(guess_msg) + res = process_stream(self.llm.stream(guess_msg), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_guess_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_guess_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk guess_msg.append(AIMessage(full_guess_text)) @@ -405,9 +386,8 @@ def select_datasource(self): if self.current_assistant and self.current_assistant.type != 4: _ds_list = get_assistant_ds(session=self.session, llm_service=self) else: - oid: str = self.current_user.oid stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where( - CoreDatasource.oid == oid) + and_(CoreDatasource.oid == self.current_user.oid)) _ds_list = [ { "id": ds.id, @@ -425,57 +405,58 @@ def select_datasource(self): full_thinking_text = '' full_text = '' - if not ignore_auto_select: - _ds_list_dict = [] - for _ds in _ds_list: - _ds_list_dict.append(_ds) - datasource_msg.append( 
- HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode()))) - - self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=self.session, - ai_modal_id=self.chat_question.ai_modal_id, - ai_modal_name=self.chat_question.ai_modal_name, - operate=OperationEnum.CHOOSE_DATASOURCE, - record_id=self.record.id, - full_message=[{'type': msg.type, - 'content': msg.content} for - msg in datasource_msg]) - - token_usage = {} - res = self.llm.stream(datasource_msg) - for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) - datasource_msg.append(AIMessage(full_text)) - - self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session, - log=self.current_logs[ - OperationEnum.CHOOSE_DATASOURCE], - full_message=[ - {'type': msg.type, 'content': msg.content} - for msg in datasource_msg], - reasoning_content=full_thinking_text, - token_usage=token_usage) - - json_str = extract_nested_json(full_text) + if settings.TABLE_EMBEDDING_ENABLED: + ds = get_ds_embedding(self.session, self.current_user, _ds_list, self.out_ds_instance, + self.chat_question.question, self.current_assistant) + yield {'content': '{"id":' + str(ds.get('id')) + '}'} + else: + _ds_list_dict = [] + for _ds in _ds_list: + _ds_list_dict.append(_ds) + datasource_msg.append( + HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode()))) + + self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=self.session, + ai_modal_id=self.chat_question.ai_modal_id, + ai_modal_name=self.chat_question.ai_modal_name, + operate=OperationEnum.CHOOSE_DATASOURCE, + record_id=self.record.id, + full_message=[{'type': msg.type, + 'content': msg.content} + for + msg in datasource_msg]) + + token_usage = {} + res = process_stream(self.llm.stream(datasource_msg), token_usage) + for chunk in res: + if chunk.get('content'): + full_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk + datasource_msg.append(AIMessage(full_text)) + + self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=self.session, + log=self.current_logs[ + OperationEnum.CHOOSE_DATASOURCE], + full_message=[ + {'type': msg.type, + 'content': msg.content} + for msg in datasource_msg], + reasoning_content=full_thinking_text, + token_usage=token_usage) + + json_str = extract_nested_json(full_text) + if json_str is None: + raise SingleMessageError(f'Cannot parse datasource from answer: {full_text}') + ds = orjson.loads(json_str) _error: Exception | None = None _datasource: int | None = None _engine_type: str | None = None try: - data: dict = _ds_list[0] if ignore_auto_select else orjson.loads(json_str) + data: dict = _ds_list[0] if ignore_auto_select else ds if data.get('id') and data.get('id') != 0: _datasource = data['id'] @@ -497,7 +478,8 @@ def select_datasource(self): self.chat_question.engine = (_ds.type_name if _ds.type != 'excel' else 'PostgreSQL') + get_version( self.ds) self.chat_question.db_schema = 
get_table_schema(session=self.session, - current_user=self.current_user, ds=self.ds) + current_user=self.current_user, ds=self.ds, + question=self.chat_question.question) _engine_type = self.chat_question.engine _chat.engine_type = _ds.type_name # save chat @@ -514,16 +496,21 @@ def select_datasource(self): except Exception as e: _error = e - if not ignore_auto_select: + if not ignore_auto_select and not settings.TABLE_EMBEDDING_ENABLED: self.record = save_select_datasource_answer(session=self.session, record_id=self.record.id, answer=orjson.dumps({'content': full_text}).decode(), datasource=_datasource, engine_type=_engine_type) + if self.ds: + oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1 + ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None - self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, - self.ds.oid if isinstance(self.ds, - CoreDatasource) else 1) - self.init_messages() + self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, oid, + ds_id) + self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, ds_id, + oid) + + self.init_messages() if _error: raise _error @@ -544,21 +531,13 @@ def generate_sql(self): full_thinking_text = '' full_sql_text = '' token_usage = {} - res = self.llm.stream(self.sql_message) + res = process_stream(self.llm.stream(self.sql_message), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_sql_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_sql_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk self.sql_message.append(AIMessage(full_sql_text)) @@ -591,18 +570,14 @@ def generate_with_sub_sql(self, sql, sub_mappings: list): full_thinking_text = '' full_dynamic_text = '' - res = self.llm.stream(dynamic_sql_msg) token_usage = {} + res = process_stream(self.llm.stream(dynamic_sql_msg), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - full_dynamic_text += chunk.content - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_dynamic_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk dynamic_sql_msg.append(AIMessage(full_dynamic_text)) @@ -654,22 +629,13 @@ def build_table_filter(self, sql: str, filters: list): in permission_sql_msg]) full_thinking_text = '' full_filter_text = '' - res = self.llm.stream(permission_sql_msg) token_usage = {} + res = process_stream(self.llm.stream(permission_sql_msg), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in 
chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_filter_text += chunk.content - # yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_filter_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') permission_sql_msg.append(AIMessage(full_filter_text)) @@ -719,21 +685,13 @@ def generate_chart(self, chart_type: Optional[str] = ''): full_thinking_text = '' full_chart_text = '' token_usage = {} - res = self.llm.stream(self.chart_message) + res = process_stream(self.llm.stream(self.chart_message), token_usage) for chunk in res: - SQLBotLogUtil.info(chunk) - reasoning_content_chunk = '' - if 'reasoning_content' in chunk.additional_kwargs: - reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') - # else: - # reasoning_content_chunk = chunk.get('reasoning_content') - if reasoning_content_chunk is None: - reasoning_content_chunk = '' - full_thinking_text += reasoning_content_chunk - - full_chart_text += chunk.content - yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk} - get_token_usage(chunk, token_usage) + if chunk.get('content'): + full_chart_text += chunk.get('content') + if chunk.get('reasoning_content'): + full_thinking_text += chunk.get('reasoning_content') + yield chunk self.chart_message.append(AIMessage(full_chart_text)) @@ -881,7 +839,7 @@ def save_sql_data(self, data_obj: Dict[str, Any]): def finish(self): return finish_record(session=self.session, record_id=self.record.id) - def execute_sql(self, sql: str, tables): + def execute_sql(self, sql: str): """Execute SQL query Args: @@ -893,7 +851,7 @@ def execute_sql(self, sql: str, tables): """ SQLBotLogUtil.info(f"Executing SQL on ds_id {self.ds.id}: {sql}") try: - return exec_sql(ds=self.ds, sql=sql, origin_column=False, table_name=tables) + return exec_sql(ds=self.ds, sql=sql, origin_column=False) except Exception as e: if isinstance(e, ParseSQLResultError): raise e @@ -922,24 +880,36 @@ def await_result(self): break yield chunk - def run_task_async(self, in_chat: bool = True): - self.future = executor.submit(self.run_task_cache, in_chat) + def run_task_async(self, in_chat: bool = True, stream: bool = True, + finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART): + if in_chat: + stream = True + self.future = executor.submit(self.run_task_cache, in_chat, stream, finish_step) - def run_task_cache(self, in_chat: bool = True): - for chunk in self.run_task(in_chat): + def run_task_cache(self, in_chat: bool = True, stream: bool = True, + finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART): + for chunk in self.run_task(in_chat, stream, finish_step): self.chunk_list.append(chunk) - def run_task(self, in_chat: bool = True): + def run_task(self, in_chat: bool = True, stream: bool = True, + finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART): + json_result: Dict[str, Any] = {'success': True} try: if self.ds: + oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1 + ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, - self.ds.oid if 
isinstance(self.ds, - CoreDatasource) else 1) + oid, ds_id) + self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, + ds_id, oid) + self.init_messages() # return id if in_chat: yield 'data:' + orjson.dumps({'type': 'id', 'id': self.get_record().id}).decode() + '\n\n' + if not stream: + json_result['record_id'] = self.get_record().id # return title if self.change_title: @@ -949,8 +919,10 @@ def run_task(self, in_chat: bool = True): brief=self.chat_question.question.strip()[:20])) if in_chat: yield 'data:' + orjson.dumps({'type': 'brief', 'brief': brief}).decode() + '\n\n' + if not stream: + json_result['title'] = brief - # select datasource if datasource is none + # select datasource if datasource is none if not self.ds: ds_res = self.select_datasource() @@ -968,7 +940,8 @@ def run_task(self, in_chat: bool = True): self.chat_question.db_schema = self.out_ds_instance.get_db_schema( self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session, current_user=self.current_user, - ds=self.ds) + ds=self.ds, + question=self.chat_question.question) else: self.validate_history_ds() @@ -988,7 +961,6 @@ def run_task(self, in_chat: bool = True): 'type': 'sql-result'}).decode() + '\n\n' if in_chat: yield 'data:' + orjson.dumps({'type': 'info', 'msg': 'sql generated'}).decode() + '\n\n' - # filter sql SQLBotLogUtil.info(full_sql_text) @@ -1022,14 +994,18 @@ def run_task(self, in_chat: bool = True): sql = self.check_save_sql(res=full_sql_text) else: sql = self.check_save_sql(res=full_sql_text) - tables = [] - SQLBotLogUtil.info(sql) + SQLBotLogUtil.info('sql: ' + sql) + + if not stream: + json_result['sql'] = sql + format_sql = sqlparse.format(sql, reindent=True) if in_chat: yield 'data:' + orjson.dumps({'content': format_sql, 'type': 'sql'}).decode() + '\n\n' else: - yield f'```sql\n{format_sql}\n```\n\n' + if stream: + yield f'```sql\n{format_sql}\n```\n\n' # execute sql real_execute_sql = sql @@ -1040,11 +1016,46 @@ def run_task(self, in_chat: bool = True): subsql) real_execute_sql = assistant_dynamic_sql - result = self.execute_sql(sql=real_execute_sql, tables=tables) - print(result) + if finish_step.value <= ChatFinishStep.GENERATE_SQL.value: + if in_chat: + yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n' + if not stream: + yield json_result + return + + result = self.execute_sql(sql=real_execute_sql) self.save_sql_data(data_obj=result) if in_chat: yield 'data:' + orjson.dumps({'content': 'execute-success', 'type': 'sql-data'}).decode() + '\n\n' + if not stream: + json_result['data'] = result.get('data') + + if finish_step.value <= ChatFinishStep.QUERY_DATA.value: + if stream: + if in_chat: + yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n' + else: + data = [] + _fields_list = [] + _fields_skip = False + for _data in result.get('data'): + _row = [] + for field in result.get('fields'): + _row.append(_data.get(field)) + if not _fields_skip: + _fields_list.append(field) + data.append(_row) + _fields_skip = True + + if not data or not _fields_list: + yield 'The SQL execution result is empty.\n\n' + else: + df = pd.DataFrame(np.array(data), columns=_fields_list) + markdown_table = df.to_markdown(index=False) + yield markdown_table + '\n\n' + else: + yield json_result + return # generate chart chart_res = self.generate_chart(chart_type) @@ -1062,39 +1073,47 @@ def run_task(self, in_chat: bool = True): SQLBotLogUtil.info(full_chart_text) chart = self.check_save_chart(res=full_chart_text) 
SQLBotLogUtil.info(chart) + + if not stream: + json_result['chart'] = chart + if in_chat: yield 'data:' + orjson.dumps( {'content': orjson.dumps(chart).decode(), 'type': 'chart'}).decode() + '\n\n' else: - data = [] - _fields = {} - if chart.get('columns'): - for _column in chart.get('columns'): - if _column: - _fields[_column.get('value')] = _column.get('name') - if chart.get('axis'): - if chart.get('axis').get('x'): - _fields[chart.get('axis').get('x').get('value')] = chart.get('axis').get('x').get('name') - if chart.get('axis').get('y'): - _fields[chart.get('axis').get('y').get('value')] = chart.get('axis').get('y').get('name') - if chart.get('axis').get('series'): - _fields[chart.get('axis').get('series').get('value')] = chart.get('axis').get('series').get( - 'name') - _fields_list = [] - _fields_skip = False - for _data in result.get('data'): - _row = [] - for field in result.get('fields'): - _row.append(_data.get(field)) - if not _fields_skip: - _fields_list.append(field if not _fields.get(field) else _fields.get(field)) - data.append(_row) - _fields_skip = True - df = pd.DataFrame(np.array(data), columns=_fields_list) - markdown_table = df.to_markdown(index=False) - yield markdown_table + '\n\n' - - record = self.finish() + if stream: + data = [] + _fields = {} + if chart.get('columns'): + for _column in chart.get('columns'): + if _column: + _fields[_column.get('value')] = _column.get('name') + if chart.get('axis'): + if chart.get('axis').get('x'): + _fields[chart.get('axis').get('x').get('value')] = chart.get('axis').get('x').get('name') + if chart.get('axis').get('y'): + _fields[chart.get('axis').get('y').get('value')] = chart.get('axis').get('y').get('name') + if chart.get('axis').get('series'): + _fields[chart.get('axis').get('series').get('value')] = chart.get('axis').get('series').get( + 'name') + _fields_list = [] + _fields_skip = False + for _data in result.get('data'): + _row = [] + for field in result.get('fields'): + _row.append(_data.get(field)) + if not _fields_skip: + _fields_list.append(field if not _fields.get(field) else _fields.get(field)) + data.append(_row) + _fields_skip = True + + if not data or not _fields_list: + yield 'The SQL execution result is empty.\n\n' + else: + df = pd.DataFrame(np.array(data), columns=_fields_list) + markdown_table = df.to_markdown(index=False) + yield markdown_table + '\n\n' + if in_chat: yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n' else: @@ -1103,8 +1122,16 @@ def run_task(self, in_chat: bool = True): yield '### generated chart picture\n\n' image_url = request_picture(self.record.chat_id, self.record.id, chart, result) SQLBotLogUtil.info(image_url) - yield f'![{chart["type"]}]({image_url})' + if stream: + yield f'![{chart["type"]}]({image_url})' + else: + json_result['image_url'] = image_url + + if not stream: + yield json_result + except Exception as e: + traceback.print_exc() error_msg: str if isinstance(e, SingleMessageError): error_msg = str(e) @@ -1120,7 +1147,14 @@ def run_task(self, in_chat: bool = True): if in_chat: yield 'data:' + orjson.dumps({'content': error_msg, 'type': 'error'}).decode() + '\n\n' else: - yield f'> ❌ **ERROR**\n\n> \n\n> {error_msg}。' + if stream: + yield f'> ❌ **ERROR**\n\n> \n\n> {error_msg}。' + else: + json_result['success'] = False + json_result['message'] = error_msg + yield json_result + finally: + self.finish() def run_recommend_questions_task_async(self): self.future = executor.submit(self.run_recommend_questions_task_cache) @@ -1280,9 +1314,11 @@ def 
request_picture(chat_id: int, record_id: int, chart: dict, data: dict):
     return request_path
 
 
-def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = {}):
+def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = None):
     try:
         if chunk.usage_metadata:
+            if token_usage is None:
+                token_usage = {}
             token_usage['input_tokens'] = chunk.usage_metadata.get('input_tokens')
             token_usage['output_tokens'] = chunk.usage_metadata.get('output_tokens')
             token_usage['total_tokens'] = chunk.usage_metadata.get('total_tokens')
@@ -1290,6 +1326,104 @@ def get_token_usage(chunk: BaseMessageChunk, token_usage: dict = {}):
         pass
 
 
+def process_stream(res: Iterator[BaseMessageChunk],
+                   token_usage: Dict[str, Any] = None,
+                   enable_tag_parsing: bool = settings.PARSE_REASONING_BLOCK_ENABLED,
+                   start_tag: str = settings.DEFAULT_REASONING_CONTENT_START,
+                   end_tag: str = settings.DEFAULT_REASONING_CONTENT_END
+                   ):
+    if token_usage is None:
+        token_usage = {}
+    in_thinking_block = False  # whether we are currently inside a reasoning block
+    current_thinking = ''  # reasoning content collected so far
+    pending_start_tag = ''  # buffers a possibly truncated start tag
+
+    for chunk in res:
+        SQLBotLogUtil.info(chunk)
+        reasoning_content_chunk = ''
+        content = chunk.content
+        output_content = ''  # content that will actually be emitted
+
+        # check reasoning_content in additional_kwargs
+        if 'reasoning_content' in chunk.additional_kwargs:
+            reasoning_content = chunk.additional_kwargs.get('reasoning_content', '')
+            if reasoning_content is None:
+                reasoning_content = ''
+
+            # accumulate the reasoning content from additional_kwargs into current_thinking
+            current_thinking += reasoning_content
+            reasoning_content_chunk = reasoning_content
+
+            # skip tag parsing only when current_thinking is not an empty string
+            if not in_thinking_block and current_thinking.strip() != '':
+                output_content = content  # emit content as-is
+                yield {
+                    'content': output_content,
+                    'reasoning_content': reasoning_content_chunk
+                }
+                get_token_usage(chunk, token_usage)
+                continue  # skip the tag-parsing logic below
+
+        # no usable reasoning content yet; run tag parsing only when it is enabled
+        # if part of a start tag was buffered earlier, prepend it to the current content
+        if pending_start_tag:
+            content = pending_start_tag + content
+            pending_start_tag = ''
+
+        # check whether a reasoning block starts here (handles a possibly truncated start tag)
+        if enable_tag_parsing and not in_thinking_block and start_tag:
+            if start_tag in content:
+                start_idx = content.index(start_tag)
+                # treat it as a real block start only when nothing precedes the start tag
+                if start_idx == 0 or content[:start_idx].strip() == '':
+                    # full tag present with no preceding text
+                    output_content += content[:start_idx]  # emit whatever precedes the start tag
+                    content = content[start_idx + len(start_tag):]  # drop the start tag
+                    in_thinking_block = True
+                else:
+                    # text precedes the start tag, so this is not a block start
+                    output_content += content
+                    content = ''
+            else:
+                # check for a partial start tag at the end of the chunk
+                for i in range(1, len(start_tag)):
+                    if content.endswith(start_tag[:i]):
+                        # buffer the partial tag only when the remaining content is whitespace
+                        if content[:-i].strip() == '':
+                            pending_start_tag = start_tag[:i]
+                            content = content[:-i]  # strip the possible partial tag
+                        output_content += content
+                        content = ''
+                        break
+
+        # handle content inside a reasoning block
+        if enable_tag_parsing and in_thinking_block and end_tag:
+            if end_tag in content:
+                # end tag found
+                end_idx = content.index(end_tag)
+                current_thinking += content[:end_idx]  # collect the reasoning content
+                reasoning_content_chunk += current_thinking  # append it to this chunk's reasoning content
+                content = content[end_idx + len(end_tag):]  # drop the end tag, keep what follows
+                current_thinking = ''  # reset the collected reasoning content
+                in_thinking_block = False
+                output_content += content  # emit the content after the end tag
+            else:
+                # keep collecting reasoning content until the end tag arrives
+                current_thinking += content
+                reasoning_content_chunk += content
+                content = ''
+
+        else:
+            # not inside a reasoning block, or tag parsing disabled: emit normally
+            output_content += content
+
+        yield {
+            'content': output_content,
+            'reasoning_content': reasoning_content_chunk
+        }
+        get_token_usage(chunk, token_usage)
+
+
 def get_lang_name(lang:
str): if lang and lang == 'en': return '英文' diff --git a/backend/apps/data_training/__init__.py b/backend/apps/data_training/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/apps/data_training/api/__init__.py b/backend/apps/data_training/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/apps/data_training/api/data_training.py b/backend/apps/data_training/api/data_training.py new file mode 100644 index 00000000..25422c2b --- /dev/null +++ b/backend/apps/data_training/api/data_training.py @@ -0,0 +1,39 @@ +from typing import Optional + +from fastapi import APIRouter, Query + +from apps.data_training.curd.data_training import page_data_training, create_training, update_training, delete_training +from apps.data_training.models.data_training_model import DataTrainingInfo +from common.core.deps import SessionDep, CurrentUser, Trans + +router = APIRouter(tags=["DataTraining"], prefix="/system/data-training") + + +@router.get("/page/{current_page}/{page_size}") +async def pager(session: SessionDep, current_user: CurrentUser, current_page: int, page_size: int, + question: Optional[str] = Query(None, description="搜索问题(可选)")): + current_page, page_size, total_count, total_pages, _list = page_data_training(session, current_page, page_size, + question, + current_user.oid) + + return { + "current_page": current_page, + "page_size": page_size, + "total_count": total_count, + "total_pages": total_pages, + "data": _list + } + + +@router.put("") +async def create_or_update(session: SessionDep, current_user: CurrentUser, trans: Trans, info: DataTrainingInfo): + oid = current_user.oid + if info.id: + return update_training(session, info, oid, trans) + else: + return create_training(session, info, oid, trans) + + +@router.delete("") +async def delete(session: SessionDep, id_list: list[int]): + delete_training(session, id_list) diff --git a/backend/apps/data_training/curd/__init__.py b/backend/apps/data_training/curd/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/apps/data_training/curd/data_training.py b/backend/apps/data_training/curd/data_training.py new file mode 100644 index 00000000..15260a9a --- /dev/null +++ b/backend/apps/data_training/curd/data_training.py @@ -0,0 +1,317 @@ +import datetime +import logging +import traceback +from typing import List, Optional +from xml.dom.minidom import parseString + +import dicttoxml +from sqlalchemy import and_, select, func, delete, update, or_ +from sqlalchemy import text +from sqlalchemy.orm.session import Session + +from apps.ai_model.embedding import EmbeddingModelCache +from apps.data_training.models.data_training_model import DataTrainingInfo, DataTraining +from apps.datasource.models.datasource import CoreDatasource +from apps.template.generate_chart.generator import get_base_data_training_template +from common.core.config import settings +from common.core.deps import SessionDep, Trans +from common.utils.embedding_threads import run_save_data_training_embeddings + + +def page_data_training(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None, + oid: Optional[int] = 1): + _list: List[DataTrainingInfo] = [] + + current_page = max(1, current_page) + page_size = max(10, page_size) + + total_count = 0 + total_pages = 0 + + if name and name.strip() != "": + keyword_pattern = f"%{name.strip()}%" + parent_ids_subquery = ( + select(DataTraining.id) + .where(and_(DataTraining.question.ilike(keyword_pattern), DataTraining.oid == oid)) # 
LIKE查询条件 + ) + else: + parent_ids_subquery = ( + select(DataTraining.id).where(and_(DataTraining.oid == oid)) + ) + + count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery()) + total_count = session.execute(count_stmt).scalar() + total_pages = (total_count + page_size - 1) // page_size + + if current_page > total_pages: + current_page = 1 + + paginated_parent_ids = ( + parent_ids_subquery + .order_by(DataTraining.create_time.desc()) + .offset((current_page - 1) * page_size) + .limit(page_size) + .subquery() + ) + + stmt = ( + select( + DataTraining.id, + DataTraining.oid, + DataTraining.datasource, + CoreDatasource.name, + DataTraining.question, + DataTraining.create_time, + DataTraining.description, + ) + .outerjoin(CoreDatasource, and_(DataTraining.datasource == CoreDatasource.id)) + .where(and_(DataTraining.id.in_(paginated_parent_ids))) + .order_by(DataTraining.create_time.desc()) + ) + + result = session.execute(stmt) + + for row in result: + _list.append(DataTrainingInfo( + id=row.id, + oid=row.oid, + datasource=row.datasource, + datasource_name=row.name, + question=row.question, + create_time=row.create_time, + description=row.description, + )) + + return current_page, page_size, total_count, total_pages, _list + + +def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans): + create_time = datetime.datetime.now() + if info.datasource is None: + raise Exception(trans("i18n_data_training.datasource_cannot_be_none")) + parent = DataTraining(question=info.question, create_time=create_time, description=info.description, oid=oid, + datasource=info.datasource) + + exists = session.query( + session.query(DataTraining).filter( + and_(DataTraining.question == info.question, DataTraining.oid == oid, + DataTraining.datasource == info.datasource)).exists()).scalar() + if exists: + raise Exception(trans("i18n_data_training.exists_in_db")) + + result = DataTraining(**parent.model_dump()) + + session.add(parent) + session.flush() + session.refresh(parent) + + result.id = parent.id + session.commit() + + # embedding + run_save_data_training_embeddings([result.id]) + + return result.id + + +def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans): + if info.datasource is None: + raise Exception(trans("i18n_data_training.datasource_cannot_be_none")) + + count = session.query(DataTraining).filter( + DataTraining.id == info.id + ).count() + if count == 0: + raise Exception(trans('i18n_data_training.data_training_not_exists')) + + exists = session.query( + session.query(DataTraining).filter( + and_(DataTraining.question == info.question, DataTraining.oid == oid, + DataTraining.datasource == info.datasource, + DataTraining.id != info.id)).exists()).scalar() + if exists: + raise Exception(trans("i18n_data_training.exists_in_db")) + + stmt = update(DataTraining).where(and_(DataTraining.id == info.id)).values( + question=info.question, + description=info.description, + datasource=info.datasource, + ) + session.execute(stmt) + session.commit() + + # embedding + run_save_data_training_embeddings([info.id]) + + return info.id + + +def delete_training(session: SessionDep, ids: list[int]): + stmt = delete(DataTraining).where(and_(DataTraining.id.in_(ids))) + session.execute(stmt) + session.commit() + + +# def run_save_embeddings(ids: List[int]): +# executor.submit(save_embeddings, ids) +# +# +# def fill_empty_embeddings(): +# executor.submit(run_fill_empty_embeddings) + + +def run_fill_empty_embeddings(session: Session): + if not 
settings.EMBEDDING_ENABLED:
+        return
+
+    stmt = select(DataTraining.id).where(and_(DataTraining.embedding.is_(None)))
+    results = session.execute(stmt).scalars().all()
+
+    save_embeddings(session, results)
+
+
+def save_embeddings(session: Session, ids: List[int]):
+    if not settings.EMBEDDING_ENABLED:
+        return
+
+    if not ids or len(ids) == 0:
+        return
+    try:
+
+        _list = session.query(DataTraining).filter(and_(DataTraining.id.in_(ids))).all()
+
+        _question_list = [item.question for item in _list]
+
+        model = EmbeddingModelCache.get_model()
+
+        results = model.embed_documents(_question_list)
+
+        for index in range(len(results)):
+            item = results[index]
+            stmt = update(DataTraining).where(and_(DataTraining.id == _list[index].id)).values(embedding=item)
+            session.execute(stmt)
+        session.commit()
+
+    except Exception:
+        traceback.print_exc()
+
+
+embedding_sql = f"""
+SELECT id, datasource, question, similarity
+FROM
+(SELECT id, datasource, question, oid,
+( 1 - (embedding <=> :embedding_array) ) AS similarity
+FROM data_training AS child
+) TEMP
+WHERE similarity > {settings.EMBEDDING_DATA_TRAINING_SIMILARITY} and oid = :oid and datasource = :datasource
+ORDER BY similarity DESC
+LIMIT {settings.EMBEDDING_DATA_TRAINING_TOP_COUNT}
+"""
+
+
+def select_training_by_question(session: SessionDep, question: str, oid: int, datasource: int):
+    if question.strip() == "":
+        return []
+
+    _list: List[DataTraining] = []
+
+    # maybe use label later?
+    stmt = (
+        select(
+            DataTraining.id,
+            DataTraining.question,
+        )
+        .where(
+            and_(or_(text(":sentence ILIKE '%' || question || '%'"), text("question ILIKE '%' || :sentence || '%'")),
+                 DataTraining.oid == oid,
+                 DataTraining.datasource == datasource)
+        )
+    )
+
+    results = session.execute(stmt, {'sentence': question}).fetchall()
+
+    for row in results:
+        _list.append(DataTraining(id=row.id, question=row.question))
+
+    if settings.EMBEDDING_ENABLED:
+        try:
+            model = EmbeddingModelCache.get_model()
+
+            embedding = model.embed_query(question)
+
+            results = session.execute(text(embedding_sql),
+                                      {'embedding_array': str(embedding), 'oid': oid, 'datasource': datasource})
+
+            for row in results:
+                _list.append(DataTraining(id=row.id, question=row.question))
+
+        except Exception:
+            traceback.print_exc()
+
+    _map: dict = {}
+    _ids: list[int] = []
+    for row in _list:
+        if row.id in _ids:
+            continue
+        else:
+            _ids.append(row.id)
+
+    if len(_ids) == 0:
+        return []
+
+    t_list = session.query(DataTraining.id, DataTraining.datasource, DataTraining.question,
+                           DataTraining.description).filter(
+        and_(DataTraining.id.in_(_ids))).all()
+
+    for row in t_list:
+        _map[row.id] = {'question': row.question, 'suggestion-answer': row.description}
+
+    _results: list[dict] = []
+    for key in _map.keys():
+        _results.append(_map.get(key))
+
+    return _results
+
+
+def to_xml_string(_dict: list[dict] | dict, root: str = 'sql-examples') -> str:
+    item_name_func = lambda x: 'sql-example' if x == 'sql-examples' else 'item'
+    dicttoxml.LOG.setLevel(logging.ERROR)
+    xml = dicttoxml.dicttoxml(_dict,
+                              cdata=['question', 'suggestion-answer'],
+                              custom_root=root,
+                              item_func=item_name_func,
+                              xml_declaration=False,
+                              encoding='utf-8',
+                              attr_type=False).decode('utf-8')
+    pretty_xml = parseString(xml).toprettyxml()
+
+    if pretty_xml.startswith('<?xml'):
+        end_index = pretty_xml.index('>') + 1
+        pretty_xml = pretty_xml[end_index:].lstrip()
+
+    # replace all XML escape entities with their literal characters
+    escape_map = {
+        '&lt;': '<',
+        '&gt;': '>',
+        '&amp;': '&',
+        '&quot;': '"',
+        '&apos;': "'"
+    }
+    for escaped, original in escape_map.items():
+        pretty_xml = pretty_xml.replace(escaped, original)
+
return pretty_xml + + +def get_training_template(session: SessionDep, question: str, datasource: int, oid: Optional[int] = 1) -> str: + if not oid: + oid = 1 + if not datasource: + return '' + _results = select_training_by_question(session, question, oid, datasource) + if _results and len(_results) > 0: + data_training = to_xml_string(_results) + template = get_base_data_training_template().format(data_training=data_training) + return template + else: + return '' diff --git a/backend/apps/data_training/models/__init__.py b/backend/apps/data_training/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/apps/data_training/models/data_training_model.py b/backend/apps/data_training/models/data_training_model.py new file mode 100644 index 00000000..b2806400 --- /dev/null +++ b/backend/apps/data_training/models/data_training_model.py @@ -0,0 +1,28 @@ +from datetime import datetime +from typing import List, Optional + +from pgvector.sqlalchemy import VECTOR +from pydantic import BaseModel +from sqlalchemy import Column, Text, BigInteger, DateTime, Identity +from sqlmodel import SQLModel, Field + + +class DataTraining(SQLModel, table=True): + __tablename__ = "data_training" + id: Optional[int] = Field(sa_column=Column(BigInteger, Identity(always=True), primary_key=True)) + oid: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True, default=1)) + datasource: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True)) + create_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=False), nullable=True)) + question: Optional[str] = Field(max_length=255) + description: Optional[str] = Field(sa_column=Column(Text, nullable=True)) + embedding: Optional[List[float]] = Field(sa_column=Column(VECTOR(), nullable=True)) + + +class DataTrainingInfo(BaseModel): + id: Optional[int] = None + oid: Optional[int] = None + datasource: Optional[int] = None + datasource_name: Optional[str] = None + create_time: Optional[datetime] = None + question: Optional[str] = None + description: Optional[str] = None diff --git a/backend/apps/datasource/api/datasource.py b/backend/apps/datasource/api/datasource.py index 69cad28f..9c182e34 100644 --- a/backend/apps/datasource/api/datasource.py +++ b/backend/apps/datasource/api/datasource.py @@ -193,6 +193,7 @@ def inner(): return await asyncio.to_thread(inner) +# not used @router.post("/fieldEnum/{id}") async def field_enum(session: SessionDep, id: int): def inner(): diff --git a/backend/apps/datasource/api/table_relation.py b/backend/apps/datasource/api/table_relation.py new file mode 100644 index 00000000..6a65060a --- /dev/null +++ b/backend/apps/datasource/api/table_relation.py @@ -0,0 +1,29 @@ +# Author: Junjun +# Date: 2025/9/24 +from typing import List + +from fastapi import APIRouter + +from apps.datasource.models.datasource import CoreDatasource +from common.core.deps import SessionDep + +router = APIRouter(tags=["table_relation"], prefix="/table_relation") + + +@router.post("/save/{ds_id}") +async def save_relation(session: SessionDep, ds_id: int, relation: List[dict]): + ds = session.get(CoreDatasource, ds_id) + if ds: + ds.table_relation = relation + session.commit() + else: + raise Exception("no datasource") + return True + + +@router.post("/get/{ds_id}") +async def save_relation(session: SessionDep, ds_id: int): + ds = session.get(CoreDatasource, ds_id) + if ds: + return ds.table_relation if ds.table_relation else [] + return [] diff --git a/backend/apps/datasource/crud/datasource.py 
b/backend/apps/datasource/crud/datasource.py index b4720bad..4380b3cf 100644 --- a/backend/apps/datasource/crud/datasource.py +++ b/backend/apps/datasource/crud/datasource.py @@ -4,14 +4,16 @@ from fastapi import HTTPException from sqlalchemy import and_, text +from sqlbot_xpack.permissions.models.ds_rules import DsRules from sqlmodel import select from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user +from apps.datasource.embedding.table_embedding import get_table_embedding from apps.datasource.utils.utils import aes_decrypt from apps.db.constant import DB from apps.db.db import get_tables, get_fields, exec_sql, check_connection from apps.db.engine import get_engine_config, get_engine_conn -from apps.db.type import db_type_relation +from common.core.config import settings from common.core.deps import SessionDep, CurrentUser, Trans from common.utils.utils import deepcopy_ignore_extra from .table import get_tables_by_ds_id @@ -70,7 +72,7 @@ def create_ds(session: SessionDep, trans: Trans, user: CurrentUser, create_ds: C ds.create_by = user.id ds.oid = user.oid if user.oid is not None else 1 ds.status = "Success" - ds.type_name = db_type_relation()[ds.type] + ds.type_name = DB.get_db(ds.type).db_name record = CoreDatasource(**ds.model_dump()) session.add(record) session.flush() @@ -250,8 +252,9 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table f_list = [f for f in data.fields if f.checked] if is_normal_user(current_user): # column is checked, and, column permission for data.fields + contain_rules = session.query(DsRules).all() f_list = get_column_permission_fields(session=session, current_user=current_user, table=data.table, - fields=f_list) + fields=f_list, contain_rules=contain_rules) # row permission tree where_str = '' @@ -297,7 +300,7 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table sql = f"""SELECT "{'", "'.join(fields)}" FROM "{data.table.table_name}" {where} LIMIT 100""" - return exec_sql(ds, sql, True, [data.table.table_name]) + return exec_sql(ds, sql, True) def fieldEnum(session: SessionDep, id: int): @@ -313,7 +316,7 @@ def fieldEnum(session: SessionDep, id: int): db = DB.get_db(ds.type) sql = f"""SELECT DISTINCT {db.prefix}{field.field_name}{db.suffix} FROM {db.prefix}{table.table_name}{db.suffix}""" - res = exec_sql(ds, sql, True, [table.table_name]) + res = exec_sql(ds, sql, True) return [item.get(res.get('fields')[0]) for item in res.get('data')] @@ -336,42 +339,122 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core tables = session.query(CoreTable).filter(CoreTable.ds_id == ds.id).all() conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() schema = conf.dbSchema if conf.dbSchema is not None and conf.dbSchema != "" else conf.database + + # get all field + table_ids = [table.id for table in tables] + all_fields = session.query(CoreField).filter( + and_(CoreField.table_id.in_(table_ids), CoreField.checked == True)).all() + # build dict + fields_dict = {} + for field in all_fields: + if fields_dict.get(field.table_id): + fields_dict.get(field.table_id).append(field) + else: + fields_dict[field.table_id] = [field] + + contain_rules = session.query(DsRules).all() for table in tables: - fields = session.query(CoreField).filter(and_(CoreField.table_id == table.id, CoreField.checked == True)).all() + # fields = 
session.query(CoreField).filter(and_(CoreField.table_id == table.id, CoreField.checked == True)).all() + fields = fields_dict.get(table.id) # do column permissions, filter fields - fields = get_column_permission_fields(session=session, current_user=current_user, table=table, fields=fields) + fields = get_column_permission_fields(session=session, current_user=current_user, table=table, fields=fields, + contain_rules=contain_rules) _list.append(TableAndFields(schema=schema, table=table, fields=fields)) return _list -def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource) -> str: +def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str, + embedding: bool = True) -> str: schema_str = "" table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds) if len(table_objs) == 0: return schema_str db_name = table_objs[0].schema schema_str += f"【DB_ID】 {db_name}\n【Schema】\n" + tables = [] + all_tables = [] # temp save all tables for obj in table_objs: - schema_str += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}" + schema_table = '' + schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}" table_comment = '' if obj.table.custom_comment: table_comment = obj.table.custom_comment.strip() if table_comment == '': - schema_str += '\n[\n' + schema_table += '\n[\n' else: - schema_str += f", {table_comment}\n[\n" - - field_list = [] - for field in obj.fields: - field_comment = '' - if field.custom_comment: - field_comment = field.custom_comment.strip() - if field_comment == '': - field_list.append(f"({field.field_name}:{field.field_type})") - else: - field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})") - schema_str += ",\n".join(field_list) - schema_str += '\n]\n' - # todo 外键 + schema_table += f", {table_comment}\n[\n" + + if obj.fields: + field_list = [] + for field in obj.fields: + field_comment = '' + if field.custom_comment: + field_comment = field.custom_comment.strip() + if field_comment == '': + field_list.append(f"({field.field_name}:{field.field_type})") + else: + field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})") + schema_table += ",\n".join(field_list) + schema_table += '\n]\n' + + t_obj = {"id": obj.table.id, "schema_table": schema_table} + tables.append(t_obj) + all_tables.append(t_obj) + + # do table embedding + if embedding and tables and settings.TABLE_EMBEDDING_ENABLED: + tables = get_table_embedding(session, current_user, tables, question) + # splice schema + if tables: + for s in tables: + schema_str += s.get('schema_table') + + # field relation + if tables and ds.table_relation: + relations = list(filter(lambda x: x.get('shape') == 'edge', ds.table_relation)) + if relations: + # Complete the missing table + # get tables in relation, remove irrelevant relation + embedding_table_ids = [s.get('id') for s in tables] + all_relations = list( + filter(lambda x: x.get('source').get('cell') in embedding_table_ids or x.get('target').get( + 'cell') in embedding_table_ids, relations)) + + # get relation table ids, sub embedding table ids + relation_table_ids = [] + for r in all_relations: + relation_table_ids.append(r.get('source').get('cell')) + relation_table_ids.append(r.get('target').get('cell')) + relation_table_ids = list(set(relation_table_ids)) + # get table dict + table_records 
= session.query(CoreTable).filter(CoreTable.id.in_(list(map(int, relation_table_ids)))).all() + table_dict = {} + for ele in table_records: + table_dict[ele.id] = ele.table_name + + # get lost table ids + lost_table_ids = list(set(relation_table_ids) - set(embedding_table_ids)) + # get lost table schema and splice it + lost_tables = list(filter(lambda x: x.get('id') in lost_table_ids, all_tables)) + if lost_tables: + for s in lost_tables: + schema_str += s.get('schema_table') + + # get field dict + relation_field_ids = [] + for relation in all_relations: + relation_field_ids.append(relation.get('source').get('port')) + relation_field_ids.append(relation.get('target').get('port')) + relation_field_ids = list(set(relation_field_ids)) + field_records = session.query(CoreField).filter(CoreField.id.in_(list(map(int, relation_field_ids)))).all() + field_dict = {} + for ele in field_records: + field_dict[ele.id] = ele.field_name + + if all_relations: + schema_str += '【Foreign keys】\n' + for ele in all_relations: + schema_str += f"{table_dict.get(int(ele.get('source').get('cell')))}.{field_dict.get(int(ele.get('source').get('port')))}={table_dict.get(int(ele.get('target').get('cell')))}.{field_dict.get(int(ele.get('target').get('port')))}\n" + return schema_str diff --git a/backend/apps/datasource/crud/permission.py b/backend/apps/datasource/crud/permission.py index 80c32efd..87071812 100644 --- a/backend/apps/datasource/crud/permission.py +++ b/backend/apps/datasource/crud/permission.py @@ -2,13 +2,15 @@ from typing import List, Optional from sqlalchemy import and_ -from apps.datasource.crud.row_permission import transFilterTree -from apps.datasource.models.datasource import CoreDatasource, CoreField, CoreTable -from common.core.deps import CurrentUser, SessionDep from sqlbot_xpack.permissions.api.permission import transRecord2DTO from sqlbot_xpack.permissions.models.ds_permission import DsPermission, PermissionDTO from sqlbot_xpack.permissions.models.ds_rules import DsRules +from apps.datasource.crud.row_permission import transFilterTree +from apps.datasource.models.datasource import CoreDatasource, CoreField, CoreTable +from common.core.deps import CurrentUser, SessionDep + + def get_row_permission_filters(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, tables: Optional[list] = None, single_table: Optional[CoreTable] = None): if single_table: @@ -20,10 +22,10 @@ def get_row_permission_filters(session: SessionDep, current_user: CurrentUser, d filters = [] if is_normal_user(current_user): + contain_rules = session.query(DsRules).all() for table in table_list: row_permissions = session.query(DsPermission).filter( and_(DsPermission.table_id == table.id, DsPermission.type == 'row')).all() - contain_rules = session.query(DsRules).all() res: List[PermissionDTO] = [] if row_permissions is not None: for permission in row_permissions: @@ -35,6 +37,7 @@ def get_row_permission_filters(session: SessionDep, current_user: CurrentUser, d if p_list is not None and u_list is not None and permission.id in p_list and ( current_user.id in u_list or f'{current_user.id}' in u_list): flag = True + break if flag: res.append(transRecord2DTO(session, permission)) where_str = transFilterTree(session, res, ds) @@ -43,11 +46,10 @@ def get_row_permission_filters(session: SessionDep, current_user: CurrentUser, d def get_column_permission_fields(session: SessionDep, current_user: CurrentUser, table: CoreTable, - fields: list[CoreField]): + fields: list[CoreField], contain_rules: list[DsRules]): if 
is_normal_user(current_user): column_permissions = session.query(DsPermission).filter( and_(DsPermission.table_id == table.id, DsPermission.type == 'column')).all() - contain_rules = session.query(DsRules).all() if column_permissions is not None: for permission in column_permissions: # check permission and user in same rules @@ -63,6 +65,7 @@ def get_column_permission_fields(session: SessionDep, current_user: CurrentUser, if p_list is not None and u_list is not None and permission.id in p_list and ( current_user.id in u_list or f'{current_user.id}' in u_list): flag = True + break if flag: permission_list = json.loads(permission.permissions) fields = filter_list(fields, permission_list) diff --git a/backend/apps/datasource/embedding/__init__.py b/backend/apps/datasource/embedding/__init__.py new file mode 100644 index 00000000..87bb6f5d --- /dev/null +++ b/backend/apps/datasource/embedding/__init__.py @@ -0,0 +1,2 @@ +# Author: Junjun +# Date: 2025/9/18 diff --git a/backend/apps/datasource/embedding/ds_embedding.py b/backend/apps/datasource/embedding/ds_embedding.py new file mode 100644 index 00000000..7aad7d29 --- /dev/null +++ b/backend/apps/datasource/embedding/ds_embedding.py @@ -0,0 +1,59 @@ +# Author: Junjun +# Date: 2025/9/18 +import json +import traceback +from typing import Optional + +from apps.ai_model.embedding import EmbeddingModelCache +from apps.datasource.crud.datasource import get_table_schema +from apps.datasource.embedding.utils import cosine_similarity +from apps.datasource.models.datasource import CoreDatasource +from apps.system.crud.assistant import AssistantOutDs +from common.core.deps import CurrentAssistant +from common.core.deps import SessionDep, CurrentUser +from common.utils.utils import SQLBotLogUtil + + +def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, out_ds: AssistantOutDs, + question: str, + current_assistant: Optional[CurrentAssistant] = None): + _list = [] + if current_assistant and current_assistant.type != 4: + if out_ds.ds_list: + for _ds in out_ds.ds_list: + ds = out_ds.get_ds(_ds.id) + table_schema = out_ds.get_db_schema(_ds.id) + ds_info = f"{ds.name}, {ds.description}\n" + ds_schema = ds_info + table_schema + _list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds}) + else: + for _ds in _ds_list: + if _ds.get('id'): + ds = session.get(CoreDatasource, _ds.get('id')) + table_schema = get_table_schema(session, current_user, ds, question, embedding=False) + ds_info = f"{ds.name}, {ds.description}\n" + ds_schema = ds_info + table_schema + _list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds}) + + if _list: + try: + text = [s.get('ds_schema') for s in _list] + + model = EmbeddingModelCache.get_model() + results = model.embed_documents(text) + + q_embedding = model.embed_query(question) + for index in range(len(results)): + item = results[index] + _list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item) + + _list.sort(key=lambda x: x['cosine_similarity'], reverse=True) + # print(len(_list)) + SQLBotLogUtil.info(json.dumps( + [{"id": ele.get("id"), "name": ele.get("ds").name, "cosine_similarity": ele.get("cosine_similarity")} + for ele in _list])) + ds = _list[0].get('ds') + return {"id": ds.id, "name": ds.name, "description": ds.description} + except Exception: + traceback.print_exc() + return _list diff --git a/backend/apps/datasource/embedding/table_embedding.py b/backend/apps/datasource/embedding/table_embedding.py new file mode 100644 index 
00000000..1a3fe896 --- /dev/null +++ b/backend/apps/datasource/embedding/table_embedding.py @@ -0,0 +1,41 @@ +# Author: Junjun +# Date: 2025/9/23 +import json +import time +import traceback + +from apps.ai_model.embedding import EmbeddingModelCache +from apps.datasource.embedding.utils import cosine_similarity +from common.core.config import settings +from common.core.deps import SessionDep, CurrentUser +from common.utils.utils import SQLBotLogUtil + + +def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[dict], question: str): + _list = [] + for table in tables: + _list.append({"id": table.get('id'), "schema_table": table.get('schema_table'), "cosine_similarity": 0.0}) + + if _list: + try: + text = [s.get('schema_table') for s in _list] + + model = EmbeddingModelCache.get_model() + start_time = time.time() + results = model.embed_documents(text) + end_time = time.time() + SQLBotLogUtil.info(str(end_time - start_time)) + + q_embedding = model.embed_query(question) + for index in range(len(results)): + item = results[index] + _list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item) + + _list.sort(key=lambda x: x['cosine_similarity'], reverse=True) + _list = _list[:settings.TABLE_EMBEDDING_COUNT] + # print(len(_list)) + SQLBotLogUtil.info(json.dumps(_list)) + return _list + except Exception: + traceback.print_exc() + return _list diff --git a/backend/apps/datasource/embedding/utils.py b/backend/apps/datasource/embedding/utils.py new file mode 100644 index 00000000..3f6ddced --- /dev/null +++ b/backend/apps/datasource/embedding/utils.py @@ -0,0 +1,18 @@ +# Author: Junjun +# Date: 2025/9/23 +import math + + +def cosine_similarity(vec_a, vec_b): + if len(vec_a) != len(vec_b): + raise ValueError("The vector dimension must be the same") + + dot_product = sum(a * b for a, b in zip(vec_a, vec_b)) + + norm_a = math.sqrt(sum(a * a for a in vec_a)) + norm_b = math.sqrt(sum(b * b for b in vec_b)) + + if norm_a == 0 or norm_b == 0: + return 0.0 + + return dot_product / (norm_a * norm_b) diff --git a/backend/apps/datasource/models/datasource.py b/backend/apps/datasource/models/datasource.py index 6496a857..78f23916 100644 --- a/backend/apps/datasource/models/datasource.py +++ b/backend/apps/datasource/models/datasource.py @@ -2,13 +2,14 @@ from typing import List, Optional from pydantic import BaseModel -from sqlalchemy import Column, Text, BigInteger, DateTime, Integer, Identity +from sqlalchemy import Column, Text, BigInteger, DateTime, Identity +from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import SQLModel, Field class CoreDatasource(SQLModel, table=True): __tablename__ = "core_datasource" - id: int = Field(sa_column=Column(Integer, Identity(always=True), nullable=False, primary_key=True)) + id: int = Field(sa_column=Column(BigInteger, Identity(always=True), nullable=False, primary_key=True)) name: str = Field(max_length=128, nullable=False) description: str = Field(max_length=512, nullable=True) type: str = Field(max_length=64) @@ -19,11 +20,12 @@ class CoreDatasource(SQLModel, table=True): status: str = Field(max_length=64, nullable=True) num: str = Field(max_length=256, nullable=True) oid: int = Field(sa_column=Column(BigInteger())) + table_relation: List = Field(sa_column=Column(JSONB, nullable=True)) class CoreTable(SQLModel, table=True): __tablename__ = "core_table" - id: int = Field(sa_column=Column(Integer, Identity(always=True), nullable=False, primary_key=True)) + id: int = Field(sa_column=Column(BigInteger, 
Identity(always=True), nullable=False, primary_key=True)) ds_id: int = Field(sa_column=Column(BigInteger())) checked: bool = Field(default=True) table_name: str = Field(sa_column=Column(Text)) @@ -33,7 +35,7 @@ class CoreTable(SQLModel, table=True): class CoreField(SQLModel, table=True): __tablename__ = "core_field" - id: int = Field(sa_column=Column(Integer, Identity(always=True), nullable=False, primary_key=True)) + id: int = Field(sa_column=Column(BigInteger, Identity(always=True), nullable=False, primary_key=True)) ds_id: int = Field(sa_column=Column(BigInteger())) table_id: int = Field(sa_column=Column(BigInteger())) checked: bool = Field(default=True) diff --git a/backend/apps/db/constant.py b/backend/apps/db/constant.py index 67074ac1..4194d8b7 100644 --- a/backend/apps/db/constant.py +++ b/backend/apps/db/constant.py @@ -13,19 +13,20 @@ def __init__(self, type_name): class DB(Enum): - mysql = ('mysql', '`', '`', ConnectType.sqlalchemy) - sqlServer = ('sqlServer', '[', ']', ConnectType.sqlalchemy) - pg = ('pg', '"', '"', ConnectType.sqlalchemy) - excel = ('excel', '"', '"', ConnectType.sqlalchemy) - oracle = ('oracle', '"', '"', ConnectType.sqlalchemy) - ck = ('ck', '"', '"', ConnectType.sqlalchemy) - dm = ('dm', '"', '"', ConnectType.py_driver) - doris = ('doris', '`', '`', ConnectType.py_driver) - redshift = ('redshift', '"', '"', ConnectType.py_driver) - es = ('es', '"', '"', ConnectType.py_driver) - - def __init__(self, type, prefix, suffix, connect_type: ConnectType): + mysql = ('mysql', 'MySQL', '`', '`', ConnectType.sqlalchemy) + sqlServer = ('sqlServer', 'Microsoft SQL Server', '[', ']', ConnectType.sqlalchemy) + pg = ('pg', 'PostgreSQL', '"', '"', ConnectType.sqlalchemy) + excel = ('excel', 'Excel/CSV', '"', '"', ConnectType.sqlalchemy) + oracle = ('oracle', 'Oracle', '"', '"', ConnectType.sqlalchemy) + ck = ('ck', 'ClickHouse', '"', '"', ConnectType.sqlalchemy) + dm = ('dm', '达梦', '"', '"', ConnectType.py_driver) + doris = ('doris', 'Apache Doris', '`', '`', ConnectType.py_driver) + redshift = ('redshift', 'AWS Redshift', '"', '"', ConnectType.py_driver) + es = ('es', 'Elasticsearch', '"', '"', ConnectType.py_driver) + + def __init__(self, type, db_name, prefix, suffix, connect_type: ConnectType): self.type = type + self.db_name = db_name self.prefix = prefix self.suffix = suffix self.connect_type = connect_type diff --git a/backend/apps/db/db.py b/backend/apps/db/db.py index a1da1492..1a22dc35 100644 --- a/backend/apps/db/db.py +++ b/backend/apps/db/db.py @@ -5,6 +5,8 @@ from decimal import Decimal from typing import Optional +import pymssql + from apps.db.db_sql import get_table_sql, get_field_sql, get_version_sql from common.error import ParseSQLResultError @@ -24,7 +26,7 @@ from common.core.deps import Trans from common.utils.utils import SQLBotLogUtil from fastapi import HTTPException -from apps.db.es_engine import get_es_connect, get_es_index, get_es_fields, get_es_data +from apps.db.es_engine import get_es_connect, get_es_index, get_es_fields, get_es_data_by_http def get_uri(ds: CoreDatasource) -> str: @@ -70,6 +72,35 @@ def get_uri_from_config(type: str, conf: DatasourceConf) -> str: return db_url +def get_extra_config(conf: DatasourceConf): + config_dict = {} + if conf.extraJdbc: + config_arr = conf.extraJdbc.split("&") + for config in config_arr: + kv = config.split("=") + if len(kv) == 2 and kv[0] and kv[1]: + config_dict[kv[0]] = kv[1] + else: + raise Exception(f'param: {config} is error') + return config_dict + + +def get_origin_connect(type: str, conf: 
DatasourceConf): + extra_config_dict = get_extra_config(conf) + if type == "sqlServer": + return pymssql.connect( + server=conf.host, + port=str(conf.port), + user=conf.username, + password=conf.password, + database=conf.database, + timeout=conf.timeout, + tds_version='7.0', # options: '4.2', '7.0', '8.0' ..., + **extra_config_dict + ) + + +# use sqlalchemy def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine: conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() if conf.timeout is None: @@ -87,7 +118,8 @@ def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine: connect_args={"connect_timeout": conf.timeout}, pool_timeout=conf.timeout) elif ds.type == 'sqlServer': - engine = create_engine(get_uri(ds), pool_timeout=conf.timeout) + engine = create_engine('mssql+pymssql://', creator=lambda: get_origin_connect(ds.type, conf), + pool_timeout=conf.timeout) elif ds.type == 'oracle': engine = create_engine(get_uri(ds), pool_timeout=conf.timeout) @@ -119,9 +151,10 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs return False else: conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) + extra_config_dict = get_extra_config(conf) if ds.type == 'dm': with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, - port=conf.port) as conn, conn.cursor() as cursor: + port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: try: cursor.execute('select 1', timeout=10).fetchall() SQLBotLogUtil.info("success") @@ -134,7 +167,7 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs elif ds.type == 'doris': with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=10, - read_timeout=10) as conn, conn.cursor() as cursor: + read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: try: cursor.execute('select 1') SQLBotLogUtil.info("success") @@ -148,7 +181,7 @@ def check_connection(trans: Optional[Trans], ds: CoreDatasource | AssistantOutDs with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, - timeout=10) as conn, conn.cursor() as cursor: + timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: try: cursor.execute('select 1') SQLBotLogUtil.info("success") @@ -205,16 +238,17 @@ def get_version(ds: CoreDatasource | AssistantOutDsSchema): res = result.fetchall() version = res[0][0] else: + extra_config_dict = get_extra_config(conf) if ds.type == 'dm': with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, port=conf.port) as conn, conn.cursor() as cursor: - cursor.execute(sql, timeout=10) + cursor.execute(sql, timeout=10, **extra_config_dict) res = cursor.fetchall() version = res[0][0] elif ds.type == 'doris': with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=10, - read_timeout=10) as conn, conn.cursor() as cursor: + read_timeout=10, **extra_config_dict) as conn, conn.cursor() as cursor: cursor.execute(sql) res = cursor.fetchall() version = res[0][0] @@ -233,29 +267,30 @@ def get_schema(ds: CoreDatasource): with get_session(ds) as session: sql: str = '' if ds.type == "sqlServer": - sql = f"""select name from sys.schemas""" + sql = """select name from sys.schemas""" elif ds.type == "pg" or ds.type == "excel": sql = """SELECT nspname FROM 
pg_namespace""" elif ds.type == "oracle": - sql = f"""select * from all_users""" + sql = """select * from all_users""" with session.execute(text(sql)) as result: res = result.fetchall() res_list = [item[0] for item in res] return res_list else: + extra_config_dict = get_extra_config(conf) if ds.type == 'dm': with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, - port=conf.port) as conn, conn.cursor() as cursor: - cursor.execute(f"""select OBJECT_NAME from dba_objects where object_type='SCH'""", timeout=conf.timeout) + port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute("""select OBJECT_NAME from dba_objects where object_type='SCH'""", timeout=conf.timeout) res = cursor.fetchall() res_list = [item[0] for item in res] return res_list elif ds.type == 'redshift': with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, - timeout=conf.timeout) as conn, conn.cursor() as cursor: - cursor.execute(f"""SELECT nspname FROM pg_namespace""") + timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute("""SELECT nspname FROM pg_namespace""") res = cursor.fetchall() res_list = [item[0] for item in res] return res_list @@ -264,34 +299,35 @@ def get_schema(ds: CoreDatasource): def get_tables(ds: CoreDatasource): conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() db = DB.get_db(ds.type) - sql = get_table_sql(ds, conf, get_version(ds)) + sql, sql_param = get_table_sql(ds, conf, get_version(ds)) if db.connect_type == ConnectType.sqlalchemy: with get_session(ds) as session: - with session.execute(text(sql)) as result: + with session.execute(text(sql), {"param": sql_param}) as result: res = result.fetchall() res_list = [TableSchema(*item) for item in res] return res_list else: + extra_config_dict = get_extra_config(conf) if ds.type == 'dm': with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, - port=conf.port) as conn, conn.cursor() as cursor: - cursor.execute(sql, timeout=conf.timeout) + port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql, {"param": sql_param}, timeout=conf.timeout) res = cursor.fetchall() res_list = [TableSchema(*item) for item in res] return res_list elif ds.type == 'doris': with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=conf.timeout, - read_timeout=conf.timeout) as conn, conn.cursor() as cursor: - cursor.execute(sql) + read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql, (sql_param,)) res = cursor.fetchall() res_list = [TableSchema(*item) for item in res] return res_list elif ds.type == 'redshift': with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, - timeout=conf.timeout) as conn, conn.cursor() as cursor: - cursor.execute(sql) + timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql, (sql_param,)) res = cursor.fetchall() res_list = [TableSchema(*item) for item in res] return res_list @@ -304,34 +340,35 @@ def get_tables(ds: CoreDatasource): def get_fields(ds: CoreDatasource, table_name: str = None): conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() db = 
DB.get_db(ds.type) - sql = get_field_sql(ds, conf, table_name) + sql, p1, p2 = get_field_sql(ds, conf, table_name) if db.connect_type == ConnectType.sqlalchemy: with get_session(ds) as session: - with session.execute(text(sql)) as result: + with session.execute(text(sql), {"param1": p1, "param2": p2}) as result: res = result.fetchall() res_list = [ColumnSchema(*item) for item in res] return res_list else: + extra_config_dict = get_extra_config(conf) if ds.type == 'dm': with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, - port=conf.port) as conn, conn.cursor() as cursor: - cursor.execute(sql, timeout=conf.timeout) + port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql, {"param1": p1, "param2": p2}, timeout=conf.timeout) res = cursor.fetchall() res_list = [ColumnSchema(*item) for item in res] return res_list elif ds.type == 'doris': with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=conf.timeout, - read_timeout=conf.timeout) as conn, conn.cursor() as cursor: - cursor.execute(sql) + read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql, (p1, p2)) res = cursor.fetchall() res_list = [ColumnSchema(*item) for item in res] return res_list elif ds.type == 'redshift': with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, - timeout=conf.timeout) as conn, conn.cursor() as cursor: - cursor.execute(sql) + timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: + cursor.execute(sql, (p1, p2)) res = cursor.fetchall() res_list = [ColumnSchema(*item) for item in res] return res_list @@ -341,7 +378,7 @@ def get_fields(ds: CoreDatasource, table_name: str = None): return res_list -def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=False, table_name=None): +def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=False): while sql.endswith(';'): sql = sql[:-1] @@ -363,9 +400,10 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= raise ParseSQLResultError(str(ex)) else: conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) + extra_config_dict = get_extra_config(conf) if ds.type == 'dm': with dmPython.connect(user=conf.username, password=conf.password, server=conf.host, - port=conf.port) as conn, conn.cursor() as cursor: + port=conf.port, **extra_config_dict) as conn, conn.cursor() as cursor: try: cursor.execute(sql, timeout=conf.timeout) res = cursor.fetchall() @@ -384,7 +422,7 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= elif ds.type == 'doris': with pymysql.connect(user=conf.username, passwd=conf.password, host=conf.host, port=conf.port, db=conf.database, connect_timeout=conf.timeout, - read_timeout=conf.timeout) as conn, conn.cursor() as cursor: + read_timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: try: cursor.execute(sql) res = cursor.fetchall() @@ -403,7 +441,7 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= elif ds.type == 'redshift': with redshift_connector.connect(host=conf.host, port=conf.port, database=conf.database, user=conf.username, password=conf.password, - timeout=conf.timeout) as conn, conn.cursor() as cursor: + timeout=conf.timeout, **extra_config_dict) as conn, conn.cursor() as cursor: try: 
cursor.execute(sql) res = cursor.fetchall() @@ -421,17 +459,16 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column= raise ParseSQLResultError(str(ex)) elif ds.type == 'es': try: - if table_name and table_name[0]: - res, columns = get_es_data(conf, sql, table_name[0]) - columns = [field[0] for field in columns] if origin_column else [field[0].lower() for - field in - columns] - result_list = [ - {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in - enumerate(tuple_item)} - for tuple_item in res - ] - return {"fields": columns, "data": result_list, - "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} + res, columns = get_es_data_by_http(conf, sql) + columns = [field.get('name') for field in columns] if origin_column else [field.get('name').lower() for + field in + columns] + result_list = [ + {str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in + enumerate(tuple(tuple_item))} + for tuple_item in res + ] + return {"fields": columns, "data": result_list, + "sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))} except Exception as ex: raise Exception(str(ex)) diff --git a/backend/apps/db/db_sql.py b/backend/apps/db/db_sql.py index 10af38f8..66616b70 100644 --- a/backend/apps/db/db_sql.py +++ b/backend/apps/db/db_sql.py @@ -5,27 +5,27 @@ def get_version_sql(ds: CoreDatasource, conf: DatasourceConf): if ds.type == "mysql" or ds.type == "doris": - return f""" + return """ SELECT VERSION() """ elif ds.type == "sqlServer": - return f""" + return """ select SERVERPROPERTY('ProductVersion') """ elif ds.type == "pg" or ds.type == "excel": - return f""" + return """ SELECT current_setting('server_version') """ elif ds.type == "oracle": - return f""" + return """ SELECT version FROM v$instance """ elif ds.type == "ck": - return f""" + return """ select version() """ elif ds.type == 'dm': - return f""" + return """ SELECT * FROM v$version """ elif ds.type == 'redshift': @@ -33,18 +33,18 @@ def get_version_sql(ds: CoreDatasource, conf: DatasourceConf): def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = ''): - if ds.type == "mysql" or ds.type == "doris": - return f""" + if ds.type == "mysql": + return """ SELECT TABLE_NAME, TABLE_COMMENT FROM information_schema.TABLES WHERE - TABLE_SCHEMA = '{conf.database}' - """ + TABLE_SCHEMA = :param + """, conf.database elif ds.type == "sqlServer": - return f""" + return """ SELECT TABLE_NAME AS [TABLE_NAME], ISNULL(ep.value, '') AS [TABLE_COMMENT] @@ -57,10 +57,10 @@ def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = '' AND ep.name = 'MS_Description' WHERE t.TABLE_TYPE IN ('BASE TABLE', 'VIEW') - AND t.TABLE_SCHEMA = '{conf.dbSchema}' - """ + AND t.TABLE_SCHEMA = :param + """, conf.dbSchema elif ds.type == "pg" or ds.type == "excel": - return f""" + return """ SELECT c.relname AS TABLE_NAME, COALESCE(d.description, obj_description(c.oid)) AS TABLE_COMMENT FROM pg_class c @@ -68,59 +68,59 @@ def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = '' pg_namespace n ON n.oid = c.relnamespace LEFT JOIN pg_description d ON d.objoid = c.oid AND d.objsubid = 0 - WHERE n.nspname = '{conf.dbSchema}' + WHERE n.nspname = :param AND c.relkind IN ('r', 'v', 'p', 'm') AND c.relname NOT LIKE 'pg_%' AND c.relname NOT LIKE 'sql_%' ORDER BY c.relname \ - """ + """, conf.dbSchema elif ds.type == "oracle": - return f""" + return """ SELECT t.TABLE_NAME AS "TABLE_NAME", NVL(c.COMMENTS, '') AS 
"TABLE_COMMENT" FROM ( SELECT TABLE_NAME, 'TABLE' AS OBJECT_TYPE FROM DBA_TABLES - WHERE OWNER = '{conf.dbSchema}' + WHERE OWNER = :param UNION ALL SELECT VIEW_NAME AS TABLE_NAME, 'VIEW' AS OBJECT_TYPE FROM DBA_VIEWS - WHERE OWNER = '{conf.dbSchema}' + WHERE OWNER = :param ) t LEFT JOIN DBA_TAB_COMMENTS c ON t.TABLE_NAME = c.TABLE_NAME AND c.TABLE_TYPE = t.OBJECT_TYPE - AND c.OWNER = '{conf.dbSchema}' + AND c.OWNER = :param ORDER BY t.TABLE_NAME - """ + """, conf.dbSchema elif ds.type == "ck": version = int(db_version.split('.')[0]) if version < 22: - return f""" + return """ SELECT name, null as comment FROM system.tables - WHERE database = '{conf.database}' + WHERE database = :param AND engine NOT IN ('Dictionary') ORDER BY name - """ + """, conf.database else: - return f""" + return """ SELECT name, comment FROM system.tables - WHERE database = '{conf.database}' + WHERE database = :param AND engine NOT IN ('Dictionary') ORDER BY name - """ + """, conf.database elif ds.type == 'dm': - return f""" + return """ select table_name, comments from all_tab_comments - where owner='{conf.dbSchema}' + where owner=:param AND (table_type = 'TABLE' or table_type = 'VIEW') - """ + """, conf.dbSchema elif ds.type == 'redshift': - return f""" + return """ SELECT relname AS TableName, obj_description(relfilenode::regclass, 'pg_class') AS TableDescription @@ -128,13 +128,25 @@ def get_table_sql(ds: CoreDatasource, conf: DatasourceConf, db_version: str = '' pg_class WHERE relkind in ('r','p', 'f') - AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = '{conf.dbSchema}') - """ + AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = %s) + """, conf.dbSchema + elif ds.type == "doris": + return """ + SELECT + TABLE_NAME, + TABLE_COMMENT + FROM + information_schema.TABLES + WHERE + TABLE_SCHEMA = %s + """, conf.database + elif ds.type == "es": + return "", None def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = None): - if ds.type == "mysql" or ds.type == "doris": - sql1 = f""" + if ds.type == "mysql": + sql1 = """ SELECT COLUMN_NAME, DATA_TYPE, @@ -142,12 +154,12 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No FROM INFORMATION_SCHEMA.COLUMNS WHERE - TABLE_SCHEMA = '{conf.database}' + TABLE_SCHEMA = :param1 """ - sql2 = f" AND TABLE_NAME = '{table_name}'" if table_name is not None and table_name != "" else "" - return sql1 + sql2 + sql2 = " AND TABLE_NAME = :param2" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.database, table_name elif ds.type == "sqlServer": - sql1 = f""" + sql1 = """ SELECT COLUMN_NAME AS [COLUMN_NAME], DATA_TYPE AS [DATA_TYPE], @@ -160,12 +172,28 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No AND EP.minor_id = C.ORDINAL_POSITION AND EP.name = 'MS_Description' WHERE - C.TABLE_SCHEMA = '{conf.dbSchema}' + C.TABLE_SCHEMA = :param1 """ - sql2 = f" AND C.TABLE_NAME = '{table_name}'" if table_name is not None and table_name != "" else "" - return sql1 + sql2 - elif ds.type == "pg" or ds.type == "excel" or ds.type == "redshift": - sql1 = f""" + sql2 = " AND C.TABLE_NAME = :param2" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.dbSchema, table_name + elif ds.type == "pg" or ds.type == "excel": + sql1 = """ + SELECT a.attname AS COLUMN_NAME, + pg_catalog.format_type(a.atttypid, a.atttypmod) AS DATA_TYPE, + col_description(c.oid, a.attnum) AS COLUMN_COMMENT + FROM pg_catalog.pg_attribute a + JOIN + 
pg_catalog.pg_class c ON a.attrelid = c.oid + JOIN + pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = :param1 + AND a.attnum > 0 + AND NOT a.attisdropped \ + """ + sql2 = " AND c.relname = :param2" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.dbSchema, table_name + elif ds.type == "redshift": + sql1 = """ SELECT a.attname AS COLUMN_NAME, pg_catalog.format_type(a.atttypid, a.atttypmod) AS DATA_TYPE, col_description(c.oid, a.attnum) AS COLUMN_COMMENT @@ -174,14 +202,14 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No pg_catalog.pg_class c ON a.attrelid = c.oid JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE n.nspname = '{conf.dbSchema}' + WHERE n.nspname = %s AND a.attnum > 0 AND NOT a.attisdropped \ """ - sql2 = f" AND c.relname = '{table_name}'" if table_name is not None and table_name != "" else "" - return sql1 + sql2 + sql2 = " AND c.relname = %s" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.dbSchema, table_name elif ds.type == "oracle": - sql1 = f""" + sql1 = """ SELECT col.COLUMN_NAME AS "COLUMN_NAME", (CASE @@ -201,23 +229,23 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No AND col.TABLE_NAME = com.TABLE_NAME AND col.COLUMN_NAME = com.COLUMN_NAME WHERE - col.OWNER = '{conf.dbSchema}' + col.OWNER = :param1 """ - sql2 = f" AND col.TABLE_NAME = '{table_name}'" if table_name is not None and table_name != "" else "" - return sql1 + sql2 + sql2 = " AND col.TABLE_NAME = :param2" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.dbSchema, table_name elif ds.type == "ck": - sql1 = f""" + sql1 = """ SELECT name AS COLUMN_NAME, type AS DATA_TYPE, comment AS COLUMN_COMMENT FROM system.columns - WHERE database = '{conf.database}' + WHERE database = :param1 """ - sql2 = f" AND table = '{table_name}'" if table_name is not None and table_name != "" else "" - return sql1 + sql2 + sql2 = " AND table = :param2" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.database, table_name elif ds.type == 'dm': - sql1 = f""" + sql1 = """ SELECT c.COLUMN_NAME AS "COLUMN_NAME", c.DATA_TYPE AS "DATA_TYPE", @@ -230,7 +258,22 @@ def get_field_sql(ds: CoreDatasource, conf: DatasourceConf, table_name: str = No AND c.TABLE_NAME = com.TABLE_NAME AND c.COLUMN_NAME = com.COLUMN_NAME WHERE - c.OWNER = '{conf.dbSchema}' + c.OWNER = :param1 + """ + sql2 = " AND c.TABLE_NAME = :param2" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.dbSchema, table_name + elif ds.type == "doris": + sql1 = """ + SELECT + COLUMN_NAME, + DATA_TYPE, + COLUMN_COMMENT + FROM + INFORMATION_SCHEMA.COLUMNS + WHERE + TABLE_SCHEMA = %s """ - sql2 = f" AND c.TABLE_NAME = '{table_name}'" if table_name is not None and table_name != "" else "" - return sql1 + sql2 + sql2 = " AND TABLE_NAME = %s" if table_name is not None and table_name != "" else "" + return sql1 + sql2, conf.database, table_name + elif ds.type == "es": + return "", None, None diff --git a/backend/apps/db/es_engine.py b/backend/apps/db/es_engine.py index 05c5797b..f8cf1f2f 100644 --- a/backend/apps/db/es_engine.py +++ b/backend/apps/db/es_engine.py @@ -2,12 +2,13 @@ # Date: 2025/9/9 import json +from base64 import b64encode import requests from elasticsearch import Elasticsearch -from fastapi import HTTPException from apps.datasource.models.datasource import DatasourceConf +from common.error import 
SingleMessageError def get_es_connect(conf: DatasourceConf): @@ -60,29 +61,57 @@ def get_es_fields(conf: DatasourceConf, table_name: str): return res -def get_es_data(conf: DatasourceConf, sql: str, table_name: str): - r = requests.post(f"{conf.host}/_sql/translate", json={"query": sql}) - if r.json().get('error'): - print(json.dumps(r.json())) +# def get_es_data(conf: DatasourceConf, sql: str, table_name: str): +# r = requests.post(f"{conf.host}/_sql/translate", json={"query": sql}) +# if r.json().get('error'): +# print(json.dumps(r.json())) +# +# es_client = get_es_connect(conf) +# response = es_client.search( +# index=table_name, +# body=json.dumps(r.json()) +# ) +# +# # print(response) +# fields = get_es_fields(conf, table_name) +# res = [] +# for hit in response.get('hits').get('hits'): +# item = [] +# if 'fields' in hit: +# result = hit.get('fields') # {'title': ['Python'], 'age': [30]} +# for field in fields: +# v = result.get(field[0]) +# item.append(v[0]) if v else item.append(None) +# res.append(tuple(item)) +# # print(hit['fields']['title'][0]) +# # elif '_source' in hit: +# # print(hit.get('_source')) +# return res, fields - es_client = get_es_connect(conf) - response = es_client.search( - index=table_name, - body=json.dumps(r.json()) - ) - # print(response) - fields = get_es_fields(conf, table_name) - res = [] - for hit in response.get('hits').get('hits'): - item = [] - if 'fields' in hit: - result = hit.get('fields') # {'title': ['Python'], 'age': [30]} - for field in fields: - v = result.get(field[0]) - item.append(v[0]) if v else item.append(None) - res.append(tuple(item)) - # print(hit['fields']['title'][0]) - # elif '_source' in hit: - # print(hit.get('_source')) - return res, fields +def get_es_data_by_http(conf: DatasourceConf, sql: str): + url = conf.host + while url.endswith('/'): + url = url[:-1] + + host = f'{url}/_sql?format=json' + username = f"{conf.username}" + password = f"{conf.password}" + + credentials = f"{username}:{password}" + encoded_credentials = b64encode(credentials.encode()).decode() + + headers = { + "Content-Type": "application/json", + "Authorization": f"Basic {encoded_credentials}" + } + + response = requests.post(host, data=json.dumps({"query": sql}), headers=headers) + + # print(response.json()) + res = response.json() + if res.get('error'): + raise SingleMessageError(json.dumps(res)) + fields = res.get('columns') + result = res.get('rows') + return result, fields diff --git a/backend/apps/db/type.py b/backend/apps/db/type.py deleted file mode 100644 index 1e48dc65..00000000 --- a/backend/apps/db/type.py +++ /dev/null @@ -1,18 +0,0 @@ -# Author: Junjun -# Date: 2025/5/22 -from typing import Dict - - -def db_type_relation() -> Dict: - return { - "mysql": "MySQL", - "sqlServer": "Microsoft SQL Server", - "pg": "PostgreSQL", - "excel": "Excel/CSV", - "oracle": "Oracle", - "ck": "ClickHouse", - "dm": "达梦", - "doris": "Apache Doris", - "redshift": "AWS Redshift", - "es": "Elasticsearch" - } diff --git a/backend/apps/mcp/mcp.py b/backend/apps/mcp/mcp.py index 15878792..433e4727 100644 --- a/backend/apps/mcp/mcp.py +++ b/backend/apps/mcp/mcp.py @@ -1,6 +1,7 @@ # Author: Junjun # Date: 2025/7/1 - +import json +import traceback from datetime import timedelta import jwt @@ -10,15 +11,17 @@ from jwt.exceptions import InvalidTokenError from pydantic import ValidationError from sqlmodel import select +from starlette.responses import JSONResponse from apps.chat.api.chat import create_chat -from apps.chat.models.chat_model import ChatMcp, CreateChat, 
ChatStart, McpQuestion +from apps.chat.models.chat_model import ChatMcp, CreateChat, ChatStart, McpQuestion, McpAssistant, ChatQuestion, \ + ChatFinishStep from apps.chat.task.llm import LLMService from apps.system.crud.user import authenticate from apps.system.crud.user import get_db_user from apps.system.models.system_model import UserWsModel from apps.system.models.user import UserModel -from apps.system.schemas.system_schema import BaseUserDTO +from apps.system.schemas.system_schema import BaseUserDTO, AssistantHeader from apps.system.schemas.system_schema import UserInfoDTO from common.core import security from common.core.config import settings @@ -106,8 +109,93 @@ async def mcp_question(session: SessionDep, chat: McpQuestion): raise HTTPException(status_code=400, detail="Inactive user") mcp_chat = ChatMcp(token=chat.token, chat_id=chat.chat_id, question=chat.question) - # ask - llm_service = await LLMService.create(session_user, mcp_chat) - llm_service.init_record() - return StreamingResponse(llm_service.run_task(False), media_type="text/event-stream") + try: + llm_service = await LLMService.create(session_user, mcp_chat) + llm_service.init_record() + llm_service.run_task_async(False, chat.stream) + except Exception as e: + traceback.print_exc() + + if chat.stream: + def _err(_e: Exception): + yield str(_e) + '\n\n' + + return StreamingResponse(_err(e), media_type="text/event-stream") + else: + return JSONResponse( + content={'message': str(e)}, + status_code=500, + ) + if chat.stream: + return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") + else: + res = llm_service.await_result() + raw_data = {} + for chunk in res: + if chunk: + raw_data = chunk + status_code = 200 + if not raw_data.get('success'): + status_code = 500 + + return JSONResponse( + content=raw_data, + status_code=status_code, + ) + + +@router.post("/mcp_assistant", operation_id="mcp_assistant") +async def mcp_assistant(session: SessionDep, chat: McpAssistant): + session_user = BaseUserDTO(**{ + "id": -1, "account": 'sqlbot-mcp-assistant', "oid": 1, "assistant_id": -1, "password": '', "language": "zh-CN" + }) + # session_user: UserModel = get_db_user(session=session, user_id=1) + # session_user.oid = 1 + c = create_chat(session, session_user, CreateChat(origin=1), False) + + # build assistant param + configuration = {"endpoint": chat.url} + # authorization = [{"key": "x-de-token", + # "value": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ1aWQiOjEsIm9pZCI6MSwiZXhwIjoxNzU4NTEyMDA2fQ.3NR-pgnADLdXZtI3dXX5-LuxfGYRvYD9kkr2de7KRP0", + # "target": "header"}] + mcp_assistant_header = AssistantHeader(id=1, name='mcp_assist', domain='', type=1, + configuration=json.dumps(configuration), + certificate=chat.authorization) + + # assistant question + mcp_chat = ChatQuestion(chat_id=c.id, question=chat.question) + # ask + try: + llm_service = await LLMService.create(session_user, mcp_chat, mcp_assistant_header) + llm_service.init_record() + llm_service.run_task_async(False, chat.stream, ChatFinishStep.QUERY_DATA) + except Exception as e: + traceback.print_exc() + + if chat.stream: + def _err(_e: Exception): + yield str(_e) + '\n\n' + + return StreamingResponse(_err(e), media_type="text/event-stream") + else: + return JSONResponse( + content={'message': str(e)}, + status_code=500, + ) + if chat.stream: + return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") + else: + res = llm_service.await_result() + raw_data = {} + for chunk in res: + if chunk: + raw_data = chunk + 
status_code = 200 + if not raw_data.get('success'): + status_code = 500 + + return JSONResponse( + content=raw_data, + status_code=status_code, + ) diff --git a/backend/apps/system/crud/assistant.py b/backend/apps/system/crud/assistant.py index 01e0c6f0..912218b9 100644 --- a/backend/apps/system/crud/assistant.py +++ b/backend/apps/system/crud/assistant.py @@ -211,7 +211,7 @@ def get_ds_engine(ds: AssistantOutDsSchema) -> Engine: password=ds.password, database=ds.dataBase, driver='', - extraJdbc=ds.extraParams, + extraJdbc=ds.extraParams or '', dbSchema=ds.db_schema or '' ) conf.extraJdbc = '' diff --git a/backend/apps/system/schemas/system_schema.py b/backend/apps/system/schemas/system_schema.py index 79eca6f4..728ca480 100644 --- a/backend/apps/system/schemas/system_schema.py +++ b/backend/apps/system/schemas/system_schema.py @@ -1,8 +1,9 @@ +import re from typing import Optional + from pydantic import BaseModel, Field, field_validator from common.core.schemas import BaseCreatorDTO -import re EMAIL_REGEX = re.compile( r"^[a-zA-Z0-9]+([._-][a-zA-Z0-9]+)*@" @@ -14,106 +15,121 @@ r"(?=.*[~!@#$%^&*()_+\-={}|:\"<>?`\[\];',./])" r"[A-Za-z\d~!@#$%^&*()_+\-={}|:\"<>?`\[\];',./]{8,20}$" ) + + class UserStatus(BaseCreatorDTO): status: int = 1 - - + class UserLanguage(BaseModel): language: str - + class BaseUser(BaseModel): account: str = Field(min_length=1, max_length=100, description="用户账号") oid: int - + + class BaseUserDTO(BaseUser, BaseCreatorDTO): language: str = Field(pattern=r"^(zh-CN|en)$", default="zh-CN", description="用户语言") password: str status: int = 1 + def to_dict(self): return { "id": self.id, "account": self.account, "oid": self.oid } + @field_validator("language") def validate_language(cls, lang: str) -> str: if not re.fullmatch(r"^(zh-CN|en)$", lang): raise ValueError("Language must be 'zh-CN' or 'en'") return lang + class UserCreator(BaseUser): name: str = Field(min_length=1, max_length=100, description="用户名") email: str = Field(min_length=1, max_length=100, description="用户邮箱") status: int = 1 - oid_list: Optional[list[int]] = None - + oid_list: Optional[list[int]] = None + """ @field_validator("email") def validate_email(cls, lang: str) -> str: if not re.fullmatch(EMAIL_REGEX, lang): raise ValueError("Email format is invalid!") return lang """ + class UserEditor(UserCreator, BaseCreatorDTO): - pass + pass + class UserGrid(UserEditor): create_time: int language: str = "zh-CN" - #space_name: Optional[str] = None + # space_name: Optional[str] = None origin: str = '' - - + + class PwdEditor(BaseModel): pwd: str new_pwd: str - + + class UserWsBase(BaseModel): uid_list: list[int] oid: Optional[int] = None + + class UserWsDTO(UserWsBase): weight: Optional[int] = 0 + class UserWsEditor(BaseModel): uid: int - oid: int + oid: int weight: int = 0 + class UserInfoDTO(UserEditor): language: str = "zh-CN" weight: int = 0 isAdmin: bool = False - + class AssistantBase(BaseModel): name: str domain: str - type: int = 0 + type: int = 0 # 0普通小助手 1高级 4页面嵌入 configuration: Optional[str] = None description: Optional[str] = None + + class AssistantDTO(AssistantBase, BaseCreatorDTO): pass + class AssistantHeader(AssistantDTO): unique: Optional[str] = None certificate: Optional[str] = None online: bool = False - + class AssistantValidator(BaseModel): valid: bool = False id_match: bool = False domain_match: bool = False token: Optional[str] = None - + def __init__( - self, - valid: bool = False, - id_match: bool = False, - domain_match: bool = False, - token: Optional[str] = None, - **kwargs + self, + 
valid: bool = False, + id_match: bool = False, + domain_match: bool = False, + token: Optional[str] = None, + **kwargs ): super().__init__( valid=valid, @@ -122,23 +138,28 @@ def __init__( token=token, **kwargs ) - + + class WorkspaceUser(UserEditor): weight: int create_time: int - + + class UserWs(BaseCreatorDTO): name: str - + + class UserWsOption(UserWs): account: str - - + + class AssistantFieldSchema(BaseModel): id: Optional[int] = None name: Optional[str] = None type: Optional[str] = None comment: Optional[str] = None + + class AssistantTableSchema(BaseModel): id: Optional[int] = None name: Optional[str] = None @@ -147,6 +168,7 @@ class AssistantTableSchema(BaseModel): sql: Optional[str] = None fields: Optional[list[AssistantFieldSchema]] = None + class AssistantOutDsBase(BaseModel): id: Optional[int] = None name: str @@ -154,8 +176,8 @@ class AssistantOutDsBase(BaseModel): type_name: Optional[str] = None comment: Optional[str] = None description: Optional[str] = None - - + + class AssistantOutDsSchema(AssistantOutDsBase): host: Optional[str] = None port: Optional[int] = None @@ -165,7 +187,8 @@ class AssistantOutDsSchema(AssistantOutDsBase): db_schema: Optional[str] = None extraParams: Optional[str] = None tables: Optional[list[AssistantTableSchema]] = None - + + class AssistantUiSchema(BaseCreatorDTO): theme: Optional[str] = None header_font_color: Optional[str] = None @@ -179,7 +202,3 @@ class AssistantUiSchema(BaseCreatorDTO): name: Optional[str] = None welcome: Optional[str] = None welcome_desc: Optional[str] = None - - - - \ No newline at end of file diff --git a/backend/apps/template/generate_chart/generator.py b/backend/apps/template/generate_chart/generator.py index 2c886277..a519b893 100644 --- a/backend/apps/template/generate_chart/generator.py +++ b/backend/apps/template/generate_chart/generator.py @@ -8,3 +8,7 @@ def get_chart_template(): def get_base_terminology_template(): template = get_base_template() return template['template']['terminology'] + +def get_base_data_training_template(): + template = get_base_template() + return template['template']['data_training'] diff --git a/backend/apps/terminology/curd/terminology.py b/backend/apps/terminology/curd/terminology.py index 309a2c82..355f483e 100644 --- a/backend/apps/terminology/curd/terminology.py +++ b/backend/apps/terminology/curd/terminology.py @@ -1,23 +1,23 @@ import datetime import logging import traceback -from concurrent.futures import ThreadPoolExecutor -from typing import List, Optional +from typing import List, Optional, Any from xml.dom.minidom import parseString import dicttoxml +from sqlalchemy import BigInteger from sqlalchemy import and_, or_, select, func, delete, update, union -from sqlalchemy import create_engine, text +from sqlalchemy import text from sqlalchemy.orm import aliased -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm.session import Session from apps.ai_model.embedding import EmbeddingModelCache +from apps.datasource.models.datasource import CoreDatasource from apps.template.generate_chart.generator import get_base_terminology_template from apps.terminology.models.terminology_model import Terminology, TerminologyInfo from common.core.config import settings from common.core.deps import SessionDep, Trans - -executor = ThreadPoolExecutor(max_workers=200) +from common.utils.embedding_threads import run_save_terminology_embeddings def page_terminology(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None, @@ -82,6 +82,16 @@ def 
page_terminology(session: SessionDep, current_page: int = 1, page_size: int .subquery() ) + # 创建子查询来获取数据源名称,添加类型转换 + datasource_names_subquery = ( + select( + func.jsonb_array_elements(Terminology.datasource_ids).cast(BigInteger).label('ds_id'), + Terminology.id.label('term_id') + ) + .where(Terminology.id.in_(paginated_parent_ids)) + .subquery() + ) + # 主查询 stmt = ( select( @@ -89,13 +99,34 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int Terminology.word, Terminology.create_time, Terminology.description, - children_subquery.c.other_words + Terminology.specific_ds, + Terminology.datasource_ids, + children_subquery.c.other_words, + func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names') ) .outerjoin( children_subquery, Terminology.id == children_subquery.c.pid ) + # 关联数据源名称子查询和 CoreDatasource 表 + .outerjoin( + datasource_names_subquery, + datasource_names_subquery.c.term_id == Terminology.id + ) + .outerjoin( + CoreDatasource, + CoreDatasource.id == datasource_names_subquery.c.ds_id + ) .where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid)) + .group_by( + Terminology.id, + Terminology.word, + Terminology.create_time, + Terminology.description, + Terminology.specific_ds, + Terminology.datasource_ids, + children_subquery.c.other_words + ) .order_by(Terminology.create_time.desc()) ) else: @@ -118,17 +149,59 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int .subquery() ) + children_subquery = ( + select( + child.pid, + func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words') + ) + .where(child.pid.isnot(None)) + .group_by(child.pid) + .subquery() + ) + + # 创建子查询来获取数据源名称 + datasource_names_subquery = ( + select( + func.jsonb_array_elements(Terminology.datasource_ids).cast(BigInteger).label('ds_id'), + Terminology.id.label('term_id') + ) + .where(Terminology.id.in_(paginated_parent_ids)) + .subquery() + ) + stmt = ( select( Terminology.id, Terminology.word, Terminology.create_time, Terminology.description, - func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words') + Terminology.specific_ds, + Terminology.datasource_ids, + children_subquery.c.other_words, + func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names') + ) + .outerjoin( + children_subquery, + Terminology.id == children_subquery.c.pid + ) + # 关联数据源名称子查询和 CoreDatasource 表 + .outerjoin( + datasource_names_subquery, + datasource_names_subquery.c.term_id == Terminology.id + ) + .outerjoin( + CoreDatasource, + CoreDatasource.id == datasource_names_subquery.c.ds_id ) - .outerjoin(child, and_(Terminology.id == child.pid)) .where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid)) - .group_by(Terminology.id, Terminology.word) + .group_by(Terminology.id, + Terminology.word, + Terminology.create_time, + Terminology.description, + Terminology.specific_ds, + Terminology.datasource_ids, + children_subquery.c.other_words + ) .order_by(Terminology.create_time.desc()) ) @@ -141,6 +214,9 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int create_time=row.create_time, description=row.description, other_words=row.other_words if row.other_words else [], + specific_ds=row.specific_ds if row.specific_ds is not None else False, + datasource_ids=row.datasource_ids if row.datasource_ids is not None else [], + datasource_names=row.datasource_names if row.datasource_names is not None else [], )) 
return current_page, page_size, total_count, total_pages, _list @@ -148,7 +224,13 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans): create_time = datetime.datetime.now() - parent = Terminology(word=info.word, create_time=create_time, description=info.description, oid=oid) + + specific_ds = info.specific_ds if info.specific_ds is not None else False + datasource_ids = info.datasource_ids if info.datasource_ids is not None else [] + + parent = Terminology(word=info.word, create_time=create_time, description=info.description, oid=oid, + specific_ds=specific_ds, + datasource_ids=datasource_ids) words = [info.word] for child in info.other_words: @@ -177,13 +259,14 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra if other_word.strip() == "": continue _list.append( - Terminology(pid=result.id, word=other_word, create_time=create_time, oid=oid)) + Terminology(pid=result.id, word=other_word, create_time=create_time, oid=oid, + specific_ds=specific_ds, datasource_ids=datasource_ids)) session.bulk_save_objects(_list) session.flush() session.commit() # embedding - run_save_embeddings([result.id]) + run_save_terminology_embeddings([result.id]) return result.id @@ -216,9 +299,14 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra if exists: raise Exception(trans("i18n_terminology.exists_in_db")) + specific_ds = info.specific_ds if info.specific_ds is not None else False + datasource_ids = info.datasource_ids if info.datasource_ids is not None else [] + stmt = update(Terminology).where(and_(Terminology.id == info.id)).values( word=info.word, description=info.description, + specific_ds=specific_ds, + datasource_ids=datasource_ids ) session.execute(stmt) session.commit() @@ -234,13 +322,14 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra if other_word.strip() == "": continue _list.append( - Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid)) + Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid, + specific_ds=specific_ds, datasource_ids=datasource_ids)) session.bulk_save_objects(_list) session.flush() session.commit() # embedding - run_save_embeddings([info.id]) + run_save_terminology_embeddings([info.id]) return info.id @@ -251,37 +340,31 @@ def delete_terminology(session: SessionDep, ids: list[int]): session.commit() -def run_save_embeddings(ids: List[int]): - executor.submit(save_embeddings, ids) - - -def fill_empty_embeddings(): - executor.submit(run_fill_empty_embeddings) +# def run_save_embeddings(ids: List[int]): +# executor.submit(save_embeddings, ids) +# +# +# def fill_empty_embeddings(): +# executor.submit(run_fill_empty_embeddings) -def run_fill_empty_embeddings(): +def run_fill_empty_embeddings(session: Session): if not settings.EMBEDDING_ENABLED: return - engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) - session_maker = sessionmaker(bind=engine) - session = session_maker() stmt1 = select(Terminology.id).where(and_(Terminology.embedding.is_(None), Terminology.pid.is_(None))) stmt2 = select(Terminology.pid).where(and_(Terminology.embedding.is_(None), Terminology.pid.isnot(None))).distinct() combined_stmt = union(stmt1, stmt2) results = session.execute(combined_stmt).scalars().all() - save_embeddings(results) + save_embeddings(session, results) -def save_embeddings(ids: List[int]): +def save_embeddings(session: 
Session, ids: List[int]): if not settings.EMBEDDING_ENABLED: return if not ids or len(ids) == 0: return try: - engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) - session_maker = sessionmaker(bind=engine) - session = session_maker() _list = session.query(Terminology).filter(or_(Terminology.id.in_(ids), Terminology.pid.in_(ids))).all() @@ -304,17 +387,22 @@ def save_embeddings(ids: List[int]): embedding_sql = f""" SELECT id, pid, word, similarity FROM -(SELECT id, pid, word, oid, +(SELECT id, pid, word, oid, specific_ds, datasource_ids, ( 1 - (embedding <=> :embedding_array) ) AS similarity FROM terminology AS child ) TEMP -WHERE similarity > {settings.EMBEDDING_SIMILARITY} and oid = :oid +WHERE similarity > {settings.EMBEDDING_TERMINOLOGY_SIMILARITY} AND oid = :oid +AND ( + (:datasource IS NULL AND (specific_ds = false OR specific_ds IS NULL)) + OR + (:datasource IS NOT NULL AND ((specific_ds = false OR specific_ds IS NULL) OR (specific_ds = true AND datasource_ids IS NOT NULL AND datasource_ids @> jsonb_build_array(:datasource)))) +) ORDER BY similarity DESC -LIMIT {settings.EMBEDDING_TOP_COUNT} +LIMIT {settings.EMBEDDING_TERMINOLOGY_TOP_COUNT} """ -def select_terminology_by_word(session: SessionDep, word: str, oid: int): +def select_terminology_by_word(session: SessionDep, word: str, oid: int, datasource: int = None): if word.strip() == "": return [] @@ -331,7 +419,26 @@ def select_terminology_by_word(session: SessionDep, word: str, oid: int): ) ) - results = session.execute(stmt, {'sentence': word}).fetchall() + if datasource is not None: + stmt = stmt.where( + or_( + or_(Terminology.specific_ds == False, Terminology.specific_ds.is_(None)), + and_( + Terminology.specific_ds == True, + Terminology.datasource_ids.isnot(None), + text("datasource_ids @> jsonb_build_array(:datasource)") + ) + ) + ) + else: + stmt = stmt.where(or_(Terminology.specific_ds == False, Terminology.specific_ds.is_(None))) + + # 执行查询 + params: dict[str, Any] = {'sentence': word} + if datasource is not None: + params['datasource'] = datasource + + results = session.execute(stmt, params).fetchall() for row in results: _list.append(Terminology(id=row.id, word=row.word, pid=row.pid)) @@ -342,7 +449,8 @@ def select_terminology_by_word(session: SessionDep, word: str, oid: int): embedding = model.embed_query(word) - results = session.execute(text(embedding_sql), {'embedding_array': str(embedding), 'oid': oid}) + results = session.execute(text(embedding_sql), {'embedding_array': str(embedding), 'oid': oid, + 'datasource': datasource}).fetchall() for row in results: _list.append(Terminology(id=row.id, word=row.word, pid=row.pid)) @@ -418,10 +526,11 @@ def to_xml_string(_dict: list[dict] | dict, root: str = 'terminologies') -> str: return pretty_xml -def get_terminology_template(session: SessionDep, question: str, oid: Optional[int] = 1) -> str: +def get_terminology_template(session: SessionDep, question: str, oid: Optional[int] = 1, + datasource: Optional[int] = None) -> str: if not oid: oid = 1 - _results = select_terminology_by_word(session, question, oid) + _results = select_terminology_by_word(session, question, oid, datasource) if _results and len(_results) > 0: terminology = to_xml_string(_results) template = get_base_terminology_template().format(terminologies=terminology) diff --git a/backend/apps/terminology/models/terminology_model.py b/backend/apps/terminology/models/terminology_model.py index 57c35ce4..b9048659 100644 --- a/backend/apps/terminology/models/terminology_model.py +++ 
b/backend/apps/terminology/models/terminology_model.py @@ -3,7 +3,8 @@ from pgvector.sqlalchemy import VECTOR from pydantic import BaseModel -from sqlalchemy import Column, Text, BigInteger, DateTime, Identity +from sqlalchemy import Column, Text, BigInteger, DateTime, Identity, Boolean +from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import SQLModel, Field @@ -16,6 +17,8 @@ class Terminology(SQLModel, table=True): word: Optional[str] = Field(max_length=255) description: Optional[str] = Field(sa_column=Column(Text, nullable=True)) embedding: Optional[List[float]] = Field(sa_column=Column(VECTOR(), nullable=True)) + specific_ds: Optional[bool] = Field(sa_column=Column(Boolean, default=False)) + datasource_ids: Optional[list[int]] = Field(sa_column=Column(JSONB), default=[]) class TerminologyInfo(BaseModel): @@ -24,5 +27,6 @@ class TerminologyInfo(BaseModel): word: Optional[str] = None description: Optional[str] = None other_words: Optional[List[str]] = [] - - + specific_ds: Optional[bool] = False + datasource_ids: Optional[list[int]] = [] + datasource_names: Optional[list[str]] = [] diff --git a/backend/common/core/config.py b/backend/common/core/config.py index 6a922ae9..ca0468b8 100644 --- a/backend/common/core/config.py +++ b/backend/common/core/config.py @@ -89,13 +89,24 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str: LOCAL_MODEL_PATH: str = '/opt/sqlbot/models' DEFAULT_EMBEDDING_MODEL: str = 'shibing624/text2vec-base-chinese' EMBEDDING_ENABLED: bool = True - EMBEDDING_SIMILARITY: float = 0.4 - EMBEDDING_TOP_COUNT: int = 5 + EMBEDDING_DEFAULT_SIMILARITY: float = 0.4 + EMBEDDING_TERMINOLOGY_SIMILARITY: float = EMBEDDING_DEFAULT_SIMILARITY + EMBEDDING_DATA_TRAINING_SIMILARITY: float = EMBEDDING_DEFAULT_SIMILARITY + EMBEDDING_DEFAULT_TOP_COUNT: int = 5 + EMBEDDING_TERMINOLOGY_TOP_COUNT: int = EMBEDDING_DEFAULT_TOP_COUNT + EMBEDDING_DATA_TRAINING_TOP_COUNT: int = EMBEDDING_DEFAULT_TOP_COUNT + + PARSE_REASONING_BLOCK_ENABLED: bool = True + DEFAULT_REASONING_CONTENT_START: str = '' + DEFAULT_REASONING_CONTENT_END: str = '' PG_POOL_SIZE: int = 20 PG_MAX_OVERFLOW: int = 30 PG_POOL_RECYCLE: int = 3600 PG_POOL_PRE_PING: bool = True + TABLE_EMBEDDING_ENABLED: bool = False + TABLE_EMBEDDING_COUNT: int = 10 + settings = Settings() # type: ignore diff --git a/backend/common/core/response_middleware.py b/backend/common/core/response_middleware.py index 6842893e..c60a959d 100644 --- a/backend/common/core/response_middleware.py +++ b/backend/common/core/response_middleware.py @@ -1,18 +1,33 @@ import json -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import JSONResponse + from starlette.exceptions import HTTPException +from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request +from starlette.responses import JSONResponse + from common.core.config import settings from common.utils.utils import SQLBotLogUtil + + class ResponseMiddleware(BaseHTTPMiddleware): def __init__(self, app): super().__init__(app) async def dispatch(self, request, call_next): response = await call_next(request) - - if isinstance(response, JSONResponse) or request.url.path == f"{settings.API_V1_STR}/openapi.json": + + direct_paths = [ + f"{settings.API_V1_STR}/mcp/mcp_question", + f"{settings.API_V1_STR}/mcp/mcp_assistant" + ] + + route = request.scope.get("route") + # 获取定义的路径模式,例如 '/items/{item_id}' + path_pattern = '' if not route else route.path_format + + if (isinstance(response, JSONResponse) + or request.url.path == 
f"{settings.API_V1_STR}/openapi.json" + or path_pattern in direct_paths): return response if response.status_code != 200: return response @@ -21,9 +36,9 @@ async def dispatch(self, request, call_next): body = b"" async for chunk in response.body_iterator: body += chunk - + raw_data = json.loads(body.decode()) - + if isinstance(raw_data, dict) and all(k in raw_data for k in ["code", "data", "msg"]): return JSONResponse( content=raw_data, @@ -33,13 +48,13 @@ async def dispatch(self, request, call_next): if k.lower() not in ("content-length", "content-type") } ) - + wrapped_data = { "code": 0, "data": raw_data, "msg": None } - + return JSONResponse( content=wrapped_data, status_code=response.status_code, @@ -58,7 +73,7 @@ async def dispatch(self, request, call_next): if k.lower() not in ("content-length", "content-type") } ) - + return response @@ -72,7 +87,6 @@ async def http_exception_handler(request: Request, exc: HTTPException): headers={"Access-Control-Allow-Origin": "*"} ) - @staticmethod async def global_exception_handler(request: Request, exc: Exception): SQLBotLogUtil.error(f"Unhandled Exception: {str(exc)}", exc_info=True) @@ -81,4 +95,3 @@ async def global_exception_handler(request: Request, exc: Exception): content=str(exc), headers={"Access-Control-Allow-Origin": "*"} ) - diff --git a/backend/common/utils/embedding_threads.py b/backend/common/utils/embedding_threads.py new file mode 100644 index 00000000..a38b66f0 --- /dev/null +++ b/backend/common/utils/embedding_threads.py @@ -0,0 +1,33 @@ +from concurrent.futures import ThreadPoolExecutor +from typing import List + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from common.core.config import settings + +executor = ThreadPoolExecutor(max_workers=200) + +engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) +session_maker = sessionmaker(bind=engine) +session = session_maker() + + +def run_save_terminology_embeddings(ids: List[int]): + from apps.terminology.curd.terminology import save_embeddings + executor.submit(save_embeddings, session, ids) + + +def fill_empty_terminology_embeddings(): + from apps.terminology.curd.terminology import run_fill_empty_embeddings + executor.submit(run_fill_empty_embeddings, session) + + +def run_save_data_training_embeddings(ids: List[int]): + from apps.data_training.curd.data_training import save_embeddings + executor.submit(save_embeddings, session, ids) + + +def fill_empty_data_training_embeddings(): + from apps.data_training.curd.data_training import run_fill_empty_embeddings + executor.submit(run_fill_empty_embeddings, session) diff --git a/backend/locales/en.json b/backend/locales/en.json index 36cb7f95..df427789 100644 --- a/backend/locales/en.json +++ b/backend/locales/en.json @@ -42,5 +42,13 @@ "terminology_not_exists": "Terminology does not exists", "cannot_be_repeated": "Term name, synonyms cannot be repeated", "exists_in_db": "Term name, synonyms exists" + }, + "i18n_data_training": { + "datasource_cannot_be_none": "Datasource cannot be none", + "data_training_not_exists": "Example does not exists", + "exists_in_db": "Question exists" + }, + "i18n_excel_export": { + "data_is_empty": "The form data is empty, cannot export data" } } \ No newline at end of file diff --git a/backend/locales/zh-CN.json b/backend/locales/zh-CN.json index c8c67e94..63961877 100644 --- a/backend/locales/zh-CN.json +++ b/backend/locales/zh-CN.json @@ -41,6 +41,14 @@ "i18n_terminology": { "terminology_not_exists": "该术语不存在", "cannot_be_repeated": "术语名称,同义词不能重复", - 
"exists_in_db": "术语名称,同义词已存在" + "exists_in_db": "术语名称,同义词已存在" + }, + "i18n_data_training": { + "datasource_cannot_be_none": "数据源不能为空", + "data_training_not_exists": "该示例不存在", + "exists_in_db": "该问题已存在" + }, + "i18n_excel_export": { + "data_is_empty": "表单数据为空,无法导出数据" } } \ No newline at end of file diff --git a/backend/main.py b/backend/main.py index 8912ef1e..34e1ad67 100644 --- a/backend/main.py +++ b/backend/main.py @@ -15,10 +15,10 @@ from apps.system.crud.aimodel_manage import async_model_info from apps.system.crud.assistant import init_dynamic_cors from apps.system.middleware.auth import TokenMiddleware -from apps.terminology.curd.terminology import fill_empty_embeddings from common.core.config import settings from common.core.response_middleware import ResponseMiddleware, exception_handler from common.core.sqlbot_cache import init_sqlbot_cache +from common.utils.embedding_threads import fill_empty_terminology_embeddings, fill_empty_data_training_embeddings from common.utils.utils import SQLBotLogUtil @@ -27,8 +27,12 @@ def run_migrations(): command.upgrade(alembic_cfg, "head") -def init_embedding_data(): - fill_empty_embeddings() +def init_terminology_embedding_data(): + fill_empty_terminology_embeddings() + + +def init_data_training_embedding_data(): + fill_empty_data_training_embeddings() @asynccontextmanager @@ -36,7 +40,8 @@ async def lifespan(app: FastAPI): run_migrations() init_sqlbot_cache() init_dynamic_cors(app) - init_embedding_data() + init_terminology_embedding_data() + init_data_training_embedding_data() SQLBotLogUtil.info("✅ SQLBot 初始化完成") await sqlbot_xpack.core.clean_xpack_cache() await async_model_info() # 异步加密已有模型的密钥和地址 @@ -68,7 +73,7 @@ def custom_generate_unique_id(route: APIRoute) -> str: description="SQLBot MCP Server", describe_all_responses=True, describe_full_response_schema=True, - include_operations=["get_datasource_list", "get_model_list", "mcp_question", "mcp_start"] + include_operations=["get_datasource_list", "get_model_list", "mcp_question", "mcp_start", "mcp_assistant"] ) mcp.mount(mcp_app) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index a1c5b4e3..8ccdd585 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "sqlbot" -version = "1.1.3" +version = "1.1.4" description = "" requires-python = "==3.11.*" dependencies = [ diff --git a/backend/template.yaml b/backend/template.yaml index 49f07c1e..a9b8d7e4 100644 --- a/backend/template.yaml +++ b/backend/template.yaml @@ -2,17 +2,20 @@ template: terminology: | {terminologies} - + data_training: | + + {data_training} sql: system: | 你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。 你当前的任务是根据给定的表结构和用户问题生成SQL语句、可能适合展示的图表类型以及该SQL中所用到的表名。 - 我们会在块内提供给你信息,帮助你生成SQL: - 内有等信息; + 我们会在块内提供给你信息,帮助你生成SQL: + 内有等信息; 其中,:提供数据库引擎及版本信息; :以 M-Schema 格式提供数据库表结构信息; :提供一组术语,块内每一个就是术语,其中同一个内的多个代表术语的多种叫法,也就是术语与它的同义词,即该术语对应的描述,其中也可能是能够用来参考的计算公式,或者是一些其他的查询条件 + :提供一组SQL示例,你可以参考这些示例来生成你的回答,其中内是提问,内是对于该提问的解释或者对应应该回答的SQL示例 用户的提问在内,内则会提供上次执行你提供的SQL时会出现的错误信息,内的会告诉你用户当前提问的时间 @@ -224,6 +227,7 @@ template: {terminologies} + {data_training} ### 响应, 请根据上述要求直接返回JSON结果: @@ -385,12 +389,25 @@ template: {old_questions} analysis: system: | - ### 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 + + 你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。 + 你当前的任务是根据给定的数据分析数据,并给出你的分析结果。 + 我们会在块内提供给你信息,帮助你进行分析: + 内有等信息; + :提供一组术语,块内每一个就是术语,其中同一个内的多个代表术语的多种叫法,也就是术语与它的同义词,即该术语对应的描述,其中也可能是能够用来参考的计算公式,或者是一些其他的查询条件 + - ### 说明: - 你是一个数据分析师,你的任务是根据给定的数据分析数据,并给出你的分析结果。 + 你必须遵守以下规则: + + + 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 
{lang} 输出 + + + ### 下面是提供的信息 + {terminologies} + user: | ### 字段(字段别名): {fields} diff --git a/docker-compose.yaml b/docker-compose.yaml index 6f54ca75..5cd98599 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,8 +1,9 @@ services: sqlbot: - image: dataease/sqlbot:v1.1.0 + image: dataease/sqlbot container_name: sqlbot restart: always + privileged: true networks: - sqlbot-network ports: @@ -29,8 +30,9 @@ services: SQL_DEBUG: False volumes: - ./data/sqlbot/excel:/opt/sqlbot/data/excel + - ./data/sqlbot/file:/opt/sqlbot/data/file - ./data/sqlbot/images:/opt/sqlbot/images - - ./data/sqlbot/logs:/opt/sqlbot/logs + - ./data/sqlbot/logs:/opt/sqlbot/app/logs - ./data/postgresql:/var/lib/postgresql/data networks: diff --git a/frontend/package.json b/frontend/package.json index c06d2008..50e4dbce 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -18,6 +18,7 @@ "dependencies": { "@antv/g2": "^5.3.3", "@antv/s2": "^2.4.3", + "@antv/x6": "^2.18.1", "@eslint/js": "^9.28.0", "@highlightjs/vue-plugin": "^2.1.0", "@npkg/tinymce-plugins": "^0.0.7", diff --git a/frontend/public/assistant.js b/frontend/public/assistant.js index 382d9954..38a963b2 100644 --- a/frontend/public/assistant.js +++ b/frontend/public/assistant.js @@ -69,7 +69,7 @@ const getChatContainerHtml = (data) => { return `
@@ -498,6 +498,7 @@ function loadScript(src, id) { const domain_url = getDomain(src) const online = getParam(src, 'online') + const userFlag = getParam(src, 'userFlag') let url = `${domain_url}/api/v1/system/assistant/info/${id}` if (domain_url.includes('5173')) { url = url.replace('5173', '8000') @@ -534,6 +535,7 @@ } tempData['online'] = online && online.toString().toLowerCase() == 'true' + tempData['userFlag'] = userFlag initsqlbot_assistant(tempData) if (data.type == 1) { registerMessageEvent(id, tempData) @@ -708,7 +710,7 @@ contentWindow.postMessage(params, url) } } - window.sqlbot_assistant_handler[id]['refresh'] = (online) => { + window.sqlbot_assistant_handler[id]['refresh'] = (online, userFlag) => { if (online != null && typeof online != 'boolean') { throw new Error('The parameter can only be of type boolean') } @@ -719,12 +721,35 @@ if (online != null) { new_url = updateParam(new_url, 'online', online) } + if (userFlag != null) { + new_url = updateParam(new_url, 'userFlag', userFlag) + } iframe.src = 'about:blank' setTimeout(() => { iframe.src = new_url }, 500) } } + window.sqlbot_assistant_handler[id]['destroy'] = () => { + const sqlbot_root_id = 'sqlbot-assistant-root-' + id + const container_div = document.getElementById(sqlbot_root_id) + if (container_div) { + const root_div = container_div.parentNode + if (root_div?.parentNode) { + root_div.parentNode.removeChild(root_div) + } + } + + const scriptDom = document.getElementById(`sqlbot-assistant-float-script-${id}`) + if (scriptDom) { + scriptDom.parentNode.removeChild(scriptDom) + } + const propName = script_id_prefix + id + '-state' + if (window[propName]) { + delete window[propName] + } + delete window.sqlbot_assistant_handler[id] + } } // window.addEventListener('load', init) const executeWhenReady = (fn) => { diff --git a/frontend/src/api/chat.ts b/frontend/src/api/chat.ts index c1b526c7..76c436dd 100644 --- a/frontend/src/api/chat.ts +++ b/frontend/src/api/chat.ts @@ -332,5 +332,9 @@ export const chatApi = { return request.fetchStream(`/chat/recommend_questions/${record_id}`, {}, controller) }, checkLLMModel: () => request.get('/system/aimodel/default', { requestOptions: { silent: true } }), - export2Excel: (data: any) => request.post('/chat/excel/export', data, { responseType: 'blob' }), + export2Excel: (data: any) => + request.post('/chat/excel/export', data, { + responseType: 'blob', + requestOptions: { customError: true }, + }), } diff --git a/frontend/src/api/datasource.ts b/frontend/src/api/datasource.ts index 3274b3f4..4a733185 100644 --- a/frontend/src/api/datasource.ts +++ b/frontend/src/api/datasource.ts @@ -3,6 +3,8 @@ import { request } from '@/utils/request' export const datasourceApi = { check: (data: any) => request.post('/datasource/check', data), check_by_id: (id: any) => request.get(`/datasource/check/${id}`), + relationGet: (id: any) => request.post(`/table_relation/get/${id}`), + relationSave: (dsId: any, data: any) => request.post(`/table_relation/save/${dsId}`, data), add: (data: any) => request.post('/datasource/add', data), list: () => request.get('/datasource/list'), update: (data: any) => request.post('/datasource/update', data), diff --git a/frontend/src/api/training.ts b/frontend/src/api/training.ts new file mode 100644 index 00000000..c88ad13a --- /dev/null +++ b/frontend/src/api/training.ts @@ -0,0 +1,11 @@ +import { request } from '@/utils/request' + +export const trainingApi = { + getList: (pageNum: any, pageSize: any, params: any) => + 
request.get(`/system/data-training/page/${pageNum}/${pageSize}`, { + params, + }), + updateEmbedded: (data: any) => request.put('/system/data-training', data), + deleteEmbedded: (params: any) => request.delete('/system/data-training', { data: params }), + getOne: (id: any) => request.get(`/system/data-training/${id}`), +} diff --git a/frontend/src/assets/svg/401.svg b/frontend/src/assets/svg/401.svg new file mode 100644 index 00000000..d02779fa --- /dev/null +++ b/frontend/src/assets/svg/401.svg @@ -0,0 +1,29 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/frontend/src/assets/svg/icon_mindnote_outlined.svg b/frontend/src/assets/svg/icon_mindnote_outlined.svg new file mode 100644 index 00000000..d3456f2b --- /dev/null +++ b/frontend/src/assets/svg/icon_mindnote_outlined.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/frontend/src/components/layout/LayoutDsl.vue b/frontend/src/components/layout/LayoutDsl.vue index f17b8129..1c17fcd6 100644 --- a/frontend/src/components/layout/LayoutDsl.vue +++ b/frontend/src/components/layout/LayoutDsl.vue @@ -46,6 +46,10 @@ const toWorkspace = () => { const toChatIndex = () => { router.push('/chat/index') } + +const toUserIndex = () => { + router.push('/system/user') +} const route = useRoute() const showSysmenu = computed(() => { return route.path.includes('/system') @@ -55,43 +59,60 @@ const showSysmenu = computed(() => {
- -
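Taken together, the terminology changes in this patch scope retrieval by datasource with one rule: terms whose specific_ds is false or NULL stay global, while terms with specific_ds set are returned only when the active datasource id is contained in their datasource_ids array, which is what the datasource_ids @> jsonb_build_array(:datasource) check enforces in both the exact-match query and the pgvector similarity SQL. A minimal sketch of that predicate in isolation, assuming the terminology table from this patch; the connection string and the oid / datasource values are placeholders.

from sqlalchemy import create_engine, text

engine = create_engine("postgresql+psycopg2://sqlbot:sqlbot@localhost:5432/sqlbot")  # placeholder DSN

SCOPED_TERMS_SQL = text("""
SELECT id, word
FROM terminology
WHERE oid = :oid
  AND pid IS NULL
  AND (
        specific_ds IS NOT TRUE                                    -- global terms (false or NULL)
        OR (specific_ds = TRUE
            AND datasource_ids IS NOT NULL
            AND datasource_ids @> jsonb_build_array(:datasource))  -- bound to the active datasource
  )
""")

with engine.connect() as conn:
    for row in conn.execute(SCOPED_TERMS_SQL, {"oid": 1, "datasource": 42}):
        print(row.id, row.word)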